[feature](coagent)<Add antflow compatibility and a coagent demo>

This commit is contained in:
shanshi 2024-03-12 15:31:06 +08:00
parent c14b41ecec
commit 4d9b268a98
86 changed files with 3449 additions and 901 deletions

View File

@ -26,9 +26,12 @@ JUPYTER_WORK_PATH = os.environ.get("JUPYTER_WORK_PATH", None) or os.path.join(ex
WEB_CRAWL_PATH = os.environ.get("WEB_CRAWL_PATH", None) or os.path.join(executable_path, "knowledge_base")
# NEBULA_DATA storage path
NELUBA_PATH = os.environ.get("NELUBA_PATH", None) or os.path.join(executable_path, "data/neluba_data")
NEBULA_PATH = os.environ.get("NEBULA_PATH", None) or os.path.join(executable_path, "data/nebula_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
# CHROMA storage path
CHROMA_PERSISTENT_PATH = os.environ.get("CHROMA_PERSISTENT_PATH", None) or os.path.join(executable_path, "data/chroma_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
if not os.path.exists(_path):
os.makedirs(_path, exist_ok=True)
@ -58,7 +61,8 @@ NEBULA_GRAPH_SERVER = {
}
# CHROMA CONFIG
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
# CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
# CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/codefuse-chatbot-antcode/data/chroma_data'
# Default vector store type. Options: faiss, milvus, pg.
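For reference, a minimal sketch of redirecting these storage locations (the /tmp paths are hypothetical, and this assumes the file above is coagent/base_configs/env_config.py, the name it is imported under later in this commit). Because the module reads os.environ and calls os.makedirs at import time, overrides must be set before the first import:

import os

os.environ["NEBULA_PATH"] = "/tmp/coagent/nebula_data"        # hypothetical override
os.environ["CHROMA_PERSISTENT_PATH"] = "/tmp/coagent/chroma"  # hypothetical override

from coagent.base_configs import env_config  # directories are created on import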

View File

@ -7,7 +7,7 @@ from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate
from coagent.llm_models import getChatModel, getChatModelFromConfig
from coagent.llm_models import getChatModelFromConfig
from coagent.chat.utils import History, wrap_done
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
# from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)

View File

@ -22,7 +22,7 @@ from coagent.connector.configs.prompts import CODE_PROMPT_TEMPLATE
from coagent.chat.utils import History, wrap_done
from coagent.utils import BaseResponse
from .base_chat import Chat
from coagent.llm_models import getChatModel, getChatModelFromConfig
from coagent.llm_models import getChatModelFromConfig
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
@ -67,6 +67,7 @@ class CodeChat(Chat):
embed_model_path=embed_config.embed_model_path,
embed_engine=embed_config.embed_engine,
model_device=embed_config.model_device,
embed_config=embed_config
)
context = codes_res['context']

View File

@ -12,7 +12,7 @@ from langchain.schema import (
# from configs.model_config import CODE_INTERPERT_TEMPLATE
from coagent.connector.configs.prompts import CODE_INTERPERT_TEMPLATE
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
from coagent.llm_models.openai_model import getChatModelFromConfig
from coagent.llm_models.llm_config import LLMConfig
@ -53,9 +53,15 @@ class CodeIntepreter:
message = CODE_INTERPERT_TEMPLATE.format(code=code)
messages.append(message)
chat_ress = chat_model.batch(messages)
try:
chat_ress = [chat_model(message) for message in messages]
except:
chat_ress = chat_model.batch(messages)
for chat_res, code in zip(chat_ress, code_list):
res[code] = chat_res.content
try:
res[code] = chat_res.content
except:
res[code] = chat_res
return res
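The hunk above makes the interpreter tolerate both directly callable chat models and batch-only ones, and responses with or without a .content attribute. A minimal standalone sketch of the same compatibility pattern (call_compat is a hypothetical helper, not part of this commit):

def call_compat(chat_model, messages):
    # Prefer per-message invocation; fall back to .batch() for models
    # that are not directly callable.
    try:
        responses = [chat_model(message) for message in messages]
    except Exception:
        responses = chat_model.batch(messages)
    # A response may be a LangChain message (with .content) or a plain string.
    return [getattr(response, "content", response) for response in responses]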

View File

@ -27,7 +27,7 @@ class DirCrawler:
logger.info(java_file_list)
for java_file in java_file_list:
with open(java_file) as f:
with open(java_file, encoding="utf-8") as f:
java_code = ''.join(f.readlines())
java_code_dict[java_file] = java_code
return java_code_dict

View File

@ -5,6 +5,7 @@
@time: 2023/11/21 2:35 PM
@desc:
'''
import json
import time
from loguru import logger
from collections import defaultdict
@ -15,7 +16,7 @@ from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.codechat.code_search.cypher_generator import CypherGenerator
from coagent.codechat.code_search.tagger import Tagger
from coagent.embeddings.get_embedding import get_embedding
from coagent.llm_models.llm_config import LLMConfig
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
@ -29,7 +30,8 @@ MAX_DISTANCE = 1000
class CodeSearch:
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3):
def __init__(self, llm_config: LLMConfig, nh: NebulaHandler, ch: ChromaHandler, limit: int = 3,
local_graph_file_path: str = ''):
'''
init
@param nh: NebulaHandler
@ -37,7 +39,13 @@ class CodeSearch:
@param limit: limit of result
'''
self.llm_config = llm_config
self.nh = nh
if not self.nh:
with open(local_graph_file_path, 'r') as f:
self.graph = json.load(f)
self.ch = ch
self.limit = limit
@ -51,7 +59,7 @@ class CodeSearch:
tag_list = tagger.generate_tag_query(query)
logger.info(f'query tag={tag_list}')
# get all verticex
# get all vertices
vertex_list = self.nh.get_vertices().get('v', [])
vertex_vid_list = [i.as_node().get_id().as_string() for i in vertex_list]
@ -81,7 +89,7 @@ class CodeSearch:
# get most prominent package tag
package_score_dict = defaultdict(lambda: 0)
for vertex, score in vertex_score_dict.items():
for vertex, score in vertex_score_dict_final.items():
if '#' in vertex:
# get class name first
cypher = f'''MATCH (v1:class)-[e:contain]->(v2) WHERE id(v2) == '{vertex}' RETURN id(v1) as id;'''
@ -111,6 +119,53 @@ class CodeSearch:
logger.info(f'ids={ids}')
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
for vertex, score in package_score_tuple:
index = chroma_res['result']['ids'].index(vertex)
code_text = chroma_res['result']['metadatas'][index]['code_text']
res.append({
"vertex": vertex,
"code_text": code_text}
)
if len(res) >= self.limit:
break
# logger.info(f'retrival code={res}')
return res
def search_by_tag_by_graph(self, query: str):
'''
search code by tag using the local graph
@param query:
@return:
'''
tagger = Tagger()
tag_list = tagger.generate_tag_query(query)
logger.info(f'query tag={tag_list}')
# loop to get package node
package_score_dict = {}
for code, structure in self.graph.items():
score = 0
for class_name in structure['class_name_list']:
for tag in tag_list:
if tag.lower() in class_name.lower():
score += 1
for func_name_list in structure['func_name_dict'].values():
for func_name in func_name_list:
for tag in tag_list:
if tag.lower() in func_name.lower():
score += 1
package_score_dict[structure['pac_name']] = score
# get respective code
res = []
package_score_tuple = list(package_score_dict.items())
package_score_tuple.sort(key=lambda x: x[1], reverse=True)
ids = [i[0] for i in package_score_tuple]
logger.info(f'ids={ids}')
chroma_res = self.ch.get(ids=ids, include=['metadatas'])
# logger.info(chroma_res)
for vertex, score in package_score_tuple:
index = chroma_res['result']['ids'].index(vertex)
@ -121,23 +176,22 @@ class CodeSearch:
)
if len(res) >= self.limit:
break
logger.info(f'retrival code={res}')
# logger.info(f'retrival code={res}')
return res
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu"):
def search_by_desciption(self, query: str, engine: str, model_path: str = "text2vec-base-chinese", embedding_device: str = "cpu", embed_config: EmbedConfig=None):
'''
search by performing a similarity search
@param query:
@return:
'''
query = query.replace(',', '')
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device,)
query_emb = get_embedding(engine=engine, text_list=[query], model_path=model_path, embedding_device= embedding_device, embed_config=embed_config)
query_emb = query_emb[query]
query_embeddings = [query_emb]
query_result = self.ch.query(query_embeddings=query_embeddings, n_results=self.limit,
include=['metadatas', 'distances'])
logger.debug(query_result)
res = []
for idx, distance in enumerate(query_result['result']['distances'][0]):
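A hedged usage sketch of the new Nebula-free search path (paths and the LLM config are illustrative): when nh is None the constructor loads the JSON graph written by CodeImporter, and search_by_tag_by_graph scores each package by tag hits in its class and function names.

from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.llm_models.llm_config import LLMConfig

ch = ChromaHandler(path="/tmp/chroma_data", collection_name="client_local")  # hypothetical path
code_search = CodeSearch(
    llm_config=LLMConfig(model_name="gpt-3.5-turbo"),  # illustrative; a real config also needs API fields
    nh=None, ch=ch, limit=3,
    local_graph_file_path="/tmp/client_local_graph.json",  # hypothetical file from CodeImporter
)
print(code_search.search_by_tag_by_graph("intercept"))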

View File

@ -8,7 +8,7 @@
from langchain import PromptTemplate
from loguru import logger
from coagent.llm_models.openai_model import getChatModel, getChatModelFromConfig
from coagent.llm_models.openai_model import getChatModelFromConfig
from coagent.llm_models.llm_config import LLMConfig
from coagent.utils.postprocess import replace_lt_gt
from langchain.schema import (

View File

@ -6,11 +6,10 @@
@desc:
'''
import time
import json
import os
from loguru import logger
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
# from configs.server_config import CHROMA_PERSISTENT_PATH
# from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
from coagent.embeddings.get_embedding import get_embedding
@ -18,12 +17,14 @@ from coagent.llm_models.llm_config import EmbedConfig
class CodeImporter:
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler):
def __init__(self, codebase_name: str, embed_config: EmbedConfig, nh: NebulaHandler, ch: ChromaHandler,
local_graph_file_path: str):
self.codebase_name = codebase_name
# self.engine = engine
self.embed_config: EmbedConfig= embed_config
self.embed_config: EmbedConfig = embed_config
self.nh = nh
self.ch = ch
self.local_graph_file_path = local_graph_file_path
def import_code(self, static_analysis_res: dict, interpretation: dict, do_interpret: bool = True):
'''
@ -31,9 +32,14 @@ class CodeImporter:
@return:
'''
static_analysis_res = self.filter_out_vertex(static_analysis_res, interpretation)
logger.info(f'static_analysis_res={static_analysis_res}')
self.analysis_res_to_graph(static_analysis_res)
if self.nh:
self.analysis_res_to_graph(static_analysis_res)
else:
# persist to local dir
with open(self.local_graph_file_path, 'w') as f:
json.dump(static_analysis_res, f)
self.interpretation_to_db(static_analysis_res, interpretation, do_interpret)
def filter_out_vertex(self, static_analysis_res, interpretation):
@ -114,12 +120,12 @@ class CodeImporter:
# create vertex
for tag_name, value_dict in vertex_value_dict.items():
res = self.nh.insert_vertex(tag_name, value_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
# create edge
for tag_name, value_dict in edge_value_dict.items():
res = self.nh.insert_edge(tag_name, value_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
return
@ -132,7 +138,7 @@ class CodeImporter:
if do_interpret:
logger.info('start to get embeddings for interpretation')
interp_list = list(interpretation.values())
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device)
emb = get_embedding(engine=self.embed_config.embed_engine, text_list=interp_list, model_path=self.embed_config.embed_model_path, embedding_device= self.embed_config.model_device, embed_config=self.embed_config)
logger.info('get embedding done')
else:
emb = {i: [0] for i in list(interpretation.values())}
@ -161,7 +167,7 @@ class CodeImporter:
# add documents to chroma
res = self.ch.add_data(ids=ids, embeddings=embeddings, documents=documents, metadatas=metadatas)
logger.debug(res)
# logger.debug(res)
def init_graph(self):
'''
@ -169,7 +175,7 @@ class CodeImporter:
@return:
'''
res = self.nh.create_space(space_name=self.codebase_name, vid_type='FIXED_STRING(1024)')
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
time.sleep(5)
self.nh.set_space_name(self.codebase_name)
@ -179,29 +185,29 @@ class CodeImporter:
tag_name = 'package'
prop_dict = {}
res = self.nh.create_tag(tag_name, prop_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
tag_name = 'class'
prop_dict = {}
res = self.nh.create_tag(tag_name, prop_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
tag_name = 'method'
prop_dict = {}
res = self.nh.create_tag(tag_name, prop_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
# create edge type
edge_type_name = 'contain'
prop_dict = {}
res = self.nh.create_edge_type(edge_type_name, prop_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
# create edge type
edge_type_name = 'depend'
prop_dict = {}
res = self.nh.create_edge_type(edge_type_name, prop_dict)
logger.debug(res.error_msg())
# logger.debug(res.error_msg())
if __name__ == '__main__':

View File

@ -5,16 +5,15 @@
@time: 2023/11/21 下午2:25
@desc:
'''
import os
import time
import json
from typing import List
from loguru import logger
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
# from configs.server_config import CHROMA_PERSISTENT_PATH
# from configs.model_config import EMBEDDING_ENGINE
from coagent.base_configs.env_config import (
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
CHROMA_PERSISTENT_PATH
CHROMA_PERSISTENT_PATH, CB_ROOT_PATH
)
@ -35,7 +34,9 @@ class CodeBaseHandler:
language: str = 'java',
crawl_type: str = 'ZIP',
embed_config: EmbedConfig = EmbedConfig(),
llm_config: LLMConfig = LLMConfig()
llm_config: LLMConfig = LLMConfig(),
use_nh: bool = True,
local_graph_path: str = CB_ROOT_PATH
):
self.codebase_name = codebase_name
self.code_path = code_path
@ -43,11 +44,28 @@ class CodeBaseHandler:
self.crawl_type = crawl_type
self.embed_config = embed_config
self.llm_config = llm_config
self.local_graph_file_path = local_graph_path + os.sep + f'{self.codebase_name}_graph.json'
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=codebase_name)
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
time.sleep(1)
if use_nh:
try:
self.nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=codebase_name)
self.nh.add_host(NEBULA_HOST, NEBULA_STORAGED_PORT)
time.sleep(1)
except:
self.nh = None
try:
with open(self.local_graph_file_path, 'r') as f:
self.graph = json.load(f)
except:
pass
elif local_graph_path:
self.nh = None
try:
with open(self.local_graph_file_path, 'r') as f:
self.graph = json.load(f)
except:
pass
self.ch = ChromaHandler(path=CHROMA_PERSISTENT_PATH, collection_name=codebase_name)
@ -58,9 +76,10 @@ class CodeBaseHandler:
'''
# init graph to init tag and edge
code_importer = CodeImporter(embed_config=self.embed_config, codebase_name=self.codebase_name,
nh=self.nh, ch=self.ch)
code_importer.init_graph()
time.sleep(5)
nh=self.nh, ch=self.ch, local_graph_file_path=self.local_graph_file_path)
if self.nh:
code_importer.init_graph()
time.sleep(5)
# crawl code
st0 = time.time()
@ -71,7 +90,7 @@ class CodeBaseHandler:
# analyze code
logger.info('start analyze')
st1 = time.time()
code_analyzer = CodeAnalyzer(language=self.language, llm_config = self.llm_config)
code_analyzer = CodeAnalyzer(language=self.language, llm_config=self.llm_config)
static_analysis_res, interpretation = code_analyzer.analyze(code_dict, do_interpret=do_interpret)
logger.debug('analyze done, rt={}'.format(time.time() - st1))
@ -81,8 +100,12 @@ class CodeBaseHandler:
logger.debug('update codebase done, rt={}'.format(time.time() - st2))
# get KG info
stat = self.nh.get_stat()
vertices_num, edges_num = stat['vertices'], stat['edges']
if self.nh:
stat = self.nh.get_stat()
vertices_num, edges_num = stat['vertices'], stat['edges']
else:
vertices_num = 0
edges_num = 0
# get chroma info
file_num = self.ch.count()['result']
@ -95,7 +118,11 @@ class CodeBaseHandler:
@param codebase_name: name of codebase
@return:
'''
self.nh.drop_space(space_name=codebase_name)
if self.nh:
self.nh.drop_space(space_name=codebase_name)
elif self.local_graph_file_path and os.path.isfile(self.local_graph_file_path):
os.remove(self.local_graph_file_path)
self.ch.delete_collection(collection_name=codebase_name)
def crawl_code(self, zip_file=''):
@ -124,9 +151,15 @@ class CodeBaseHandler:
@param search_type: ['cypher', 'graph', 'vector']
@return:
'''
assert search_type in ['cypher', 'tag', 'description']
if self.nh:
assert search_type in ['cypher', 'tag', 'description']
else:
if search_type == 'tag':
search_type = 'tag_by_local_graph'
assert search_type in ['tag_by_local_graph', 'description']
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit)
code_search = CodeSearch(llm_config=self.llm_config, nh=self.nh, ch=self.ch, limit=limit,
local_graph_file_path=self.local_graph_file_path)
if search_type == 'cypher':
search_res = code_search.search_by_cypher(query=query)
@ -134,7 +167,11 @@ class CodeBaseHandler:
search_res = code_search.search_by_tag(query=query)
elif search_type == 'description':
search_res = code_search.search_by_desciption(
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path, embedding_device=self.embed_config.model_device)
query=query, engine=self.embed_config.embed_engine, model_path=self.embed_config.embed_model_path,
embedding_device=self.embed_config.model_device, embed_config=self.embed_config)
elif search_type == 'tag_by_local_graph':
search_res = code_search.search_by_tag_by_graph(query=query)
context, related_vertice = self.format_search_res(search_res, search_type)
return context, related_vertice
@ -160,6 +197,12 @@ class CodeBaseHandler:
for code in search_res:
context = context + code['code_text'] + '\n'
related_vertice.append(code['vertex'])
elif search_type == 'tag_by_local_graph':
context = ''
related_vertice = []
for code in search_res:
context = context + code['code_text'] + '\n'
related_vertice.append(code['vertex'])
elif search_type == 'description':
context = ''
related_vertice = []
@ -169,17 +212,63 @@ class CodeBaseHandler:
return context, related_vertice
def search_vertices(self, vertex_type="class") -> List[str]:
'''
Search all vertices by method/class type
'''
vertices = []
if self.nh:
vertices = self.nh.get_all_vertices()
vertices = [str(v.as_node().get_id()) for v in vertices["v"] if vertex_type in v.as_node().tags()]
# for v in vertices["v"]:
# logger.debug(f"{v.as_node().get_id()}, {v.as_node().tags()}")
else:
if vertex_type == "class":
vertices = [str(class_name) for code, structure in self.graph.items() for class_name in structure['class_name_list']]
elif vertex_type == "method":
vertices = [
str(methods_name)
for code, structure in self.graph.items()
for methods_names in structure['func_name_dict'].values()
for methods_name in methods_names
]
# logger.debug(vertices)
return vertices
if __name__ == '__main__':
codebase_name = 'testing'
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH
from configs.server_config import SANDBOX_SERVER
LLM_MODEL = "gpt-3.5-turbo"
llm_config = LLMConfig(
model_name=LLM_MODEL, model_device="cpu", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
src_dir = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode'
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir')
use_nh = False
local_graph_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/code_base'
CHROMA_PERSISTENT_PATH = '/Users/bingxu/Desktop/工作/大模型/chatbot/Codefuse-chatbot-antcode/data/chroma_data'
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=local_graph_path,
llm_config=llm_config, embed_config=embed_config)
# test import code
# cbh.import_code(do_interpret=True)
# query = '使用不同的HTTP请求类型GET、POST、DELETE等来执行不同的操作'
# query = '代码中一共有多少个类'
# query = 'remove 这个函数是用来做什么的'
query = '有没有函数是从字符串中删除指定字符串的功能'
query = 'intercept 函数作用是什么'
search_type = 'graph'
search_type = 'description'
limit = 2
res = cbh.search_code(query, search_type, limit)
logger.debug(res)
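For orientation, the local graph file written when use_nh=False is plain JSON keyed by code file; a hedged sketch of its shape as implied by search_by_tag_by_graph and search_vertices (keys and values below are illustrative, not taken from this commit):

local_graph_example = {  # hypothetical entry; shape inferred from this commit's readers of self.graph
    "client/src/main/java/Client.java": {
        "pac_name": "com.example.client",                        # package vertex id
        "class_name_list": ["Client"],                           # classes defined in the file
        "func_name_dict": {"Client": ["intercept", "remove"]},   # method names per class
    }
}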

View File

@ -0,0 +1,6 @@
from .base_action import BaseAction
__all__ = [
"BaseAction"
]

View File

@ -0,0 +1,16 @@
from langchain.schema import BaseRetriever, Document
class BaseAction:
def __init__(self, ):
pass
def step(self, ):
pass
def astep(self, ):
pass

View File

@ -4,25 +4,25 @@ import re, os
import copy
from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.schema import (
Memory, Task, Role, Message, PromptField, LogVerboseEnum
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
from coagent.connector.message_process import MessageUtils
from coagent.llm_models import getChatModel, getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
from coagent.connector.prompt_manager import PromptManager
from coagent.llm_models import getExtraModel, LLMConfig, getChatModelFromConfig, EmbedConfig
from coagent.connector.prompt_manager.prompt_manager import PromptManager
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.connector.utils import parse_section
# from configs.model_config import JUPYTER_WORK_PATH
# from configs.server_config import SANDBOX_SERVER
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
class BaseAgent:
def __init__(
self,
role: Role,
prompt_config: [PromptField],
prompt_config: List[PromptField],
prompt_manager_type: str = "PromptManager",
task: Task = None,
memory: Memory = None,
@ -33,8 +33,11 @@ class BaseAgent:
llm_config: LLMConfig = None,
embed_config: EmbedConfig = None,
sandbox_server: dict = {},
jupyter_work_path: str = "",
kb_root_path: str = "",
jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = KB_ROOT_PATH,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
):
@ -43,7 +46,7 @@ class BaseAgent:
self.sandbox_server = sandbox_server
self.jupyter_work_path = jupyter_work_path
self.kb_root_path = kb_root_path
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
self.message_utils = MessageUtils(role, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
self.memory = self.init_history(memory)
self.llm_config: LLMConfig = llm_config
self.embed_config: EmbedConfig = embed_config
@ -82,12 +85,8 @@ class BaseAgent:
llm_config=self.embed_config
)
memory_manager.append(query)
memory_pool = memory_manager.current_memory
else:
memory_pool = memory_manager.current_memory
memory_pool = memory_manager.get_memory_pool(query.user_name)
logger.debug(f"memory_pool: {memory_pool}")
prompt = self.prompt_manager.generate_full_prompt(
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool)
content = self.llm.predict(prompt)
@ -99,6 +98,7 @@ class BaseAgent:
logger.info(f"{self.role.role_name} content: {content}")
output_message = Message(
user_name=query.user_name,
role_name=self.role.role_name,
role_type="assistant", #self.role.role_type,
role_content=content,
@ -151,10 +151,7 @@ class BaseAgent:
self.memory = self.init_history()
def create_llm_engine(self, llm_config: LLMConfig = None, temperature=0.2, stop=None):
if llm_config is None:
return getChatModel(temperature=temperature, stop=stop)
else:
return getChatModelFromConfig(llm_config=llm_config)
return getChatModelFromConfig(llm_config=llm_config)
def registry_actions(self, actions):
'''registry llm's actions'''
@ -212,171 +209,3 @@ class BaseAgent:
def get_memory_str(self, content_key="role_content"):
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
def create_prompt(
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_pool: Memory=None, prompt_mamnger=None) -> str:
'''
prompt engineering: combines role, task, tools, docs and memory
'''
#
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background, control_key="step_content")
history_prompt = self.create_history_prompt(history)
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
# extra_system_prompt = self.role.role_prompt
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
#
memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_pool)
memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
# input_query = query.input_query
# # logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
# # logger.debug(f"{self.role.role_name} input_query: {input_query}")
# # logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
# # logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
# if "**Context:**" in self.role.role_prompt:
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
# # context = history_prompt or '""'
# # logger.debug(f"parsed_output_list: {t}")
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query})
# else:
# prompt += "\n" + PLAN_PROMPT_INPUT.format(**{"query": input_query})
task = query.task or self.task
if task_prompt is not None:
prompt += "\n" + task.task_prompt
DocInfos = ""
if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
DocInfos += f"\nDocument Information: {doc_infos}"
if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
DocInfos += f"\nCodeBase Infomation: {code_infos}"
# if selfmemory_prompt:
# prompt += "\n" + selfmemory_prompt
# if background_prompt:
# prompt += "\n" + background_prompt
# if history_prompt:
# prompt += "\n" + history_prompt
input_query = query.input_query
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
# extra_system_prompt = self.role.role_prompt
input_keys = parse_section(self.role.role_prompt, 'Input Format')
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
prompt += "\n" + BEGIN_PROMPT_INPUT
for input_key in input_keys:
if input_key == "Origin Query":
prompt += "\n**Origin Query:**\n" + query.origin_query
elif input_key == "Context":
context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
if history:
context = history_prompt + "\n" + context
if not context:
context = "there is no context"
if self.focus_agents and memory_pool_select_by_agent_key_context:
context = memory_pool_select_by_agent_key_context
prompt += "\n**Context:**\n" + context + "\n" + input_query
elif input_key == "DocInfos":
if DocInfos:
prompt += "\n**DocInfos:**\n" + DocInfos
else:
prompt += "\n**DocInfos:**\n" + "Empty"
elif input_key == "Question":
prompt += "\n**Question:**\n" + input_query
# if "**Context:**" in self.role.role_prompt:
# # logger.debug(f"parsed_output_list: {query.parsed_output_list}")
# # input_query = "'''" + "\n".join([f"###{k}###\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k]) + "'''"
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
# if history:
# context = history_prompt + "\n" + context
# if not context:
# context = "there is no context"
# # logger.debug(f"parsed_output_list: {t}")
# if "DocInfos" in prompt:
# prompt += "\n" + QUERY_CONTEXT_DOC_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
# else:
# prompt += "\n" + QUERY_CONTEXT_PROMPT_INPUT.format(**{"context": context, "query": query.origin_query, "DocInfos": DocInfos})
# else:
# prompt += "\n" + BASE_PROMPT_INPUT.format(**{"query": input_query})
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
while "{{" in prompt or "}}" in prompt:
prompt = prompt.replace("{{", "{")
prompt = prompt.replace("}}", "}")
# logger.debug(f"{self.role.role_name} prompt: {prompt}")
return prompt
def create_doc_prompt(self, message: Message) -> str:
''''''
db_docs = message.db_docs
search_docs = message.search_docs
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs])
return doc_infos or "不存在知识库辅助信息"
def create_codedoc_prompt(self, message: Message) -> str:
''''''
code_docs = message.code_docs
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
return doc_infos or "不存在代码库辅助信息"
def create_tools_prompt(self, message: Message) -> str:
tools = message.tools
tool_strings = []
tools_descs = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
tools_descs.append(f"{tool.name}: {tool.description}")
formatted_tools = "\n".join(tool_strings)
tools_desc_str = "\n".join(tools_descs)
tool_names = ", ".join([tool.name for tool in tools])
return formatted_tools, tool_names, tools_desc_str
def create_task_prompt(self, message: Message) -> str:
task = message.task or self.task
return "\n任务目标: " + task.task_prompt if task is not None else None
def create_background_prompt(self, background: Memory, control_key="role_content") -> str:
background_message = None if background is None else background.to_str_messages(content_key=control_key)
# logger.debug(f"background_message: {background_message}")
if background_message:
background_message = re.sub("}", "}}", re.sub("{", "{{", background_message))
return "\n背景信息: " + background_message if background_message else None
def create_history_prompt(self, history: Memory, control_key="role_content") -> str:
history_message = None if history is None else history.to_str_messages(content_key=control_key)
if history_message:
history_message = re.sub("}", "}}", re.sub("{", "{{", history_message))
return "\n补充对话信息: " + history_message if history_message else None
def create_selfmemory_prompt(self, selfmemory: Memory, control_key="role_content") -> str:
selfmemory_message = None if selfmemory is None else selfmemory.to_str_messages(content_key=control_key)
if selfmemory_message:
selfmemory_message = re.sub("}", "}}", re.sub("{", "{{", selfmemory_message))
return "\n补充自身对话信息: " + selfmemory_message if selfmemory_message else None

View File

@ -2,14 +2,15 @@ from typing import List, Union
import copy
from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.schema import (
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
from coagent.llm_models import LLMConfig, EmbedConfig
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
from .base_agent import BaseAgent
@ -17,7 +18,7 @@ class ExecutorAgent(BaseAgent):
def __init__(
self,
role: Role,
prompt_config: [PromptField],
prompt_config: List[PromptField],
prompt_manager_type: str= "PromptManager",
task: Task = None,
memory: Memory = None,
@ -28,14 +29,17 @@ class ExecutorAgent(BaseAgent):
llm_config: LLMConfig = None,
embed_config: EmbedConfig = None,
sandbox_server: dict = {},
jupyter_work_path: str = "",
kb_root_path: str = "",
jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = KB_ROOT_PATH,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
):
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
jupyter_work_path, kb_root_path, log_verbose
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
)
self.do_all_task = True # run all tasks
@ -45,6 +49,7 @@ class ExecutorAgent(BaseAgent):
task_executor_memory = Memory(messages=[])
# insert query
output_message = Message(
user_name=query.user_name,
role_name=self.role.role_name,
role_type="assistant", #self.role.role_type,
role_content=query.input_query,
@ -115,7 +120,7 @@ class ExecutorAgent(BaseAgent):
history: Memory, background: Memory, memory_manager: BaseMemoryManager,
task_memory: Memory) -> Union[Message, Memory]:
'''execute the llm predict by created prompt'''
memory_pool = memory_manager.current_memory
memory_pool = memory_manager.get_memory_pool(query.user_name)
prompt = self.prompt_manager.generate_full_prompt(
previous_agent_message=query, agent_long_term_memory=self_memory, ui_history=history, chain_summary_messages=background, memory_pool=memory_pool,
task_memory=task_memory)

View File

@ -3,23 +3,23 @@ import traceback
import copy
from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.schema import (
Memory, Task, Env, Role, Message, ActionStatus, PromptField, LogVerboseEnum
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.configs.agent_config import REACT_PROMPT_INPUT
from coagent.llm_models import LLMConfig, EmbedConfig
from .base_agent import BaseAgent
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.connector.prompt_manager import PromptManager
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
class ReactAgent(BaseAgent):
def __init__(
self,
role: Role,
prompt_config: [PromptField],
prompt_config: List[PromptField],
prompt_manager_type: str = "PromptManager",
task: Task = None,
memory: Memory = None,
@ -30,14 +30,17 @@ class ReactAgent(BaseAgent):
llm_config: LLMConfig = None,
embed_config: EmbedConfig = None,
sandbox_server: dict = {},
jupyter_work_path: str = "",
kb_root_path: str = "",
jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = KB_ROOT_PATH,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
):
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
jupyter_work_path, kb_root_path, log_verbose
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
)
def step(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
@ -52,6 +55,7 @@ class ReactAgent(BaseAgent):
react_memory = Memory(messages=[])
# insert query
output_message = Message(
user_name=query.user_name,
role_name=self.role.role_name,
role_type="assistant", #self.role.role_type,
role_content=query.input_query,
@ -84,9 +88,7 @@ class ReactAgent(BaseAgent):
llm_config=self.embed_config
)
memory_manager.append(query)
memory_pool = memory_manager.current_memory
else:
memory_pool = memory_manager.current_memory
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
prompt = self.prompt_manager.generate_full_prompt(
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=react_memory,
@ -142,82 +144,4 @@ class ReactAgent(BaseAgent):
title = f"<<<<{self.role.role_name}'s prompt>>>>"
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
# def create_prompt(
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_manager: BaseMemoryManager= None,
# prompt_mamnger=None) -> str:
# prompt_mamnger = PromptManager()
# prompt_mamnger.register_standard_fields()
# # input_keys = parse_section(self.role.role_prompt, 'Agent Profile')
# data_dict = {
# "agent_profile": extract_section(self.role.role_prompt, 'Agent Profile'),
# "tool_information": query.tools,
# "session_records": memory_manager,
# "reference_documents": query,
# "output_format": extract_section(self.role.role_prompt, 'Response Output Format'),
# "response": "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()]),
# }
# # logger.debug(memory_pool)
# return prompt_mamnger.generate_full_prompt(data_dict)
def create_prompt(
self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, react_memory: Memory = None, memory_pool: Memory= None,
prompt_mamnger=None) -> str:
'''
combines role, task, tools, docs and memory
'''
#
doc_infos = self.create_doc_prompt(query)
code_infos = self.create_codedoc_prompt(query)
#
formatted_tools, tool_names, _ = self.create_tools_prompt(query)
task_prompt = self.create_task_prompt(query)
background_prompt = self.create_background_prompt(background)
history_prompt = self.create_history_prompt(history)
selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
#
# extra_system_prompt = self.role.role_prompt
prompt = self.role.role_prompt.format(**{"formatted_tools": formatted_tools, "tool_names": tool_names})
# The ReAct flow iterates on its own output; content from a second trigger should be treated as conversation history
# input_query = react_memory.to_tuple_messages(content_key="step_content")
# # input_query = query.input_query + "\n" + "\n".join([f"{v}" for k, v in input_query if v])
# input_query = "\n".join([f"{v}" for k, v in input_query if v])
input_query = "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
# logger.debug(f"input_query: {input_query}")
prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
task = query.task or self.task
# if task_prompt is not None:
# prompt += "\n" + task.task_prompt
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
# prompt += f"\n知识库信息: {doc_infos}"
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
# prompt += f"\n代码库信息: {code_infos}"
# if background_prompt:
# prompt += "\n" + background_prompt
# if history_prompt:
# prompt += "\n" + history_prompt
# if selfmemory_prompt:
# prompt += "\n" + selfmemory_prompt
# logger.debug(f"{self.role.role_name} extra_system_prompt: {self.role.role_prompt}")
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
# logger.debug(f"{self.role.role_name} doc_infos: {doc_infos}")
# logger.debug(f"{self.role.role_name} tool_names: {tool_names}")
# prompt += "\n" + REACT_PROMPT_INPUT.format(**{"query": input_query})
# prompt = extra_system_prompt.format(**{"query": input_query, "doc_infos": doc_infos, "formatted_tools": formatted_tools, "tool_names": tool_names})
while "{{" in prompt or "}}" in prompt:
prompt = prompt.replace("{{", "{")
prompt = prompt.replace("}}", "}")
return prompt

View File

@ -3,13 +3,15 @@ import copy
import random
from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.schema import (
Memory, Task, Role, Message, PromptField, LogVerboseEnum
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.configs.prompts import BEGIN_PROMPT_INPUT
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.llm_models import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
from .base_agent import BaseAgent
@ -30,14 +32,17 @@ class SelectorAgent(BaseAgent):
llm_config: LLMConfig = None,
embed_config: EmbedConfig = None,
sandbox_server: dict = {},
jupyter_work_path: str = "",
kb_root_path: str = "",
jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = KB_ROOT_PATH,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
):
super().__init__(role, prompt_config, prompt_manager_type, task, memory, chat_turn,
focus_agents, focus_message_keys, llm_config, embed_config, sandbox_server,
jupyter_work_path, kb_root_path, log_verbose
jupyter_work_path, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose
)
self.group_agents = group_agents
@ -56,9 +61,8 @@ class SelectorAgent(BaseAgent):
llm_config=self.embed_config
)
memory_manager.append(query)
memory_pool = memory_manager.current_memory
else:
memory_pool = memory_manager.current_memory
memory_pool = memory_manager.get_memory_pool(query_c.user_name)
prompt = self.prompt_manager.generate_full_prompt(
previous_agent_message=query_c, agent_long_term_memory=self.memory, ui_history=history, chain_summary_messages=background, react_memory=None,
memory_pool=memory_pool, agents=self.group_agents)
@ -90,6 +94,9 @@ class SelectorAgent(BaseAgent):
for agent in self.group_agents:
if agent.role.role_name == select_message.parsed_output.get("Role", ""):
break
# pass everything except the Role field to the next agent
query_c.parsed_output.update({k:v for k,v in select_message.parsed_output.items() if k!="Role"})
for output_message in agent.astep(query_c, history, background=background, memory_manager=memory_manager):
yield output_message or select_message
# update self_memory
@ -103,6 +110,7 @@ class SelectorAgent(BaseAgent):
memory_manager.append(output_message)
select_message.parsed_output = output_message.parsed_output
select_message.spec_parsed_output.update(output_message.spec_parsed_output)
select_message.parsed_output_list.extend(output_message.parsed_output_list)
yield select_message
@ -114,77 +122,4 @@ class SelectorAgent(BaseAgent):
print("#"*len(title) + f"\n{title}\n"+ "#"*len(title)+ f"\n\n{prompt}\n\n")
for agent in self.group_agents:
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)
# def create_prompt(
# self, query: Message, memory: Memory =None, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager=None, prompt_mamnger=None) -> str:
# '''
# role\task\tools\docs\memory
# '''
# #
# doc_infos = self.create_doc_prompt(query)
# code_infos = self.create_codedoc_prompt(query)
# #
# formatted_tools, tool_names, tools_descs = self.create_tools_prompt(query)
# agent_names, agents = self.create_agent_names()
# task_prompt = self.create_task_prompt(query)
# background_prompt = self.create_background_prompt(background)
# history_prompt = self.create_history_prompt(history)
# selfmemory_prompt = self.create_selfmemory_prompt(memory, control_key="step_content")
# DocInfos = ""
# if doc_infos is not None and doc_infos!="" and doc_infos!="不存在知识库辅助信息":
# DocInfos += f"\nDocument Information: {doc_infos}"
# if code_infos is not None and code_infos!="" and code_infos!="不存在代码库辅助信息":
# DocInfos += f"\nCodeBase Infomation: {code_infos}"
# input_query = query.input_query
# logger.debug(f"{self.role.role_name} input_query: {input_query}")
# prompt = self.role.role_prompt.format(**{"agent_names": agent_names, "agents": agents, "formatted_tools": tools_descs, "tool_names": tool_names})
# #
# memory_pool_select_by_agent_key = self.select_memory_by_agent_key(memory_manager.current_memory)
# memory_pool_select_by_agent_key_context = '\n\n'.join([f"*{k}*\n{v}" for parsed_output in memory_pool_select_by_agent_key.get_parserd_output_list() for k, v in parsed_output.items() if k not in ['Action Status']])
# input_keys = parse_section(self.role.role_prompt, 'Input Format')
# #
# prompt += "\n" + BEGIN_PROMPT_INPUT
# for input_key in input_keys:
# if input_key == "Origin Query":
# prompt += "\n**Origin Query:**\n" + query.origin_query
# elif input_key == "Context":
# context = "\n".join([f"*{k}*\n{v}" for i in query.parsed_output_list for k,v in i.items() if "Action Status" !=k])
# if history:
# context = history_prompt + "\n" + context
# if not context:
# context = "there is no context"
# if self.focus_agents and memory_pool_select_by_agent_key_context:
# context = memory_pool_select_by_agent_key_context
# prompt += "\n**Context:**\n" + context + "\n" + input_query
# elif input_key == "DocInfos":
# prompt += "\n**DocInfos:**\n" + DocInfos
# elif input_key == "Question":
# prompt += "\n**Question:**\n" + input_query
# while "{{" in prompt or "}}" in prompt:
# prompt = prompt.replace("{{", "{")
# prompt = prompt.replace("}}", "}")
# # logger.debug(f"{self.role.role_name} prompt: {prompt}")
# return prompt
# def create_agent_names(self):
# random.shuffle(self.group_agents)
# agent_names = ", ".join([f'{agent.role.role_name}' for agent in self.group_agents])
# agent_descs = []
# for agent in self.group_agents:
# role_desc = agent.role.role_prompt.split("####")[1]
# while "\n\n" in role_desc:
# role_desc = role_desc.replace("\n\n", "\n")
# role_desc = role_desc.replace("\n", ",")
# agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
# return agent_names, "\n".join(agent_descs)
agent.pre_print(query=query, history=history, background=background, memory_manager=memory_manager)

View File

@ -0,0 +1,7 @@
from .flow import AgentFlow, PhaseFlow, ChainFlow
__all__ = [
"AgentFlow", "PhaseFlow", "ChainFlow"
]

View File

@ -0,0 +1,255 @@
import importlib
from typing import List, Union, Dict, Any
from loguru import logger
import os
from langchain.embeddings.base import Embeddings
from langchain.agents import Tool
from langchain.llms.base import BaseLLM, LLM
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import Message, Role, PromptField, ChainConfig
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
class AgentFlow:
def __init__(
self,
role_name: str,
agent_type: str,
role_type: str = "assistant",
agent_index: int = 0,
role_prompt: str = "",
prompt_config: List[Dict[str, Any]] = [],
prompt_manager_type: str = "PromptManager",
chat_turn: int = 3,
focus_agents: List[str] = [],
focus_messages: List[str] = [],
embeddings: Embeddings = None,
llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
**kwargs
):
self.role_type = role_type
self.role_name = role_name
self.agent_type = agent_type
self.role_prompt = role_prompt
self.agent_index = agent_index
self.prompt_config = prompt_config
self.prompt_manager_type = prompt_manager_type
self.chat_turn = chat_turn
self.focus_agents = focus_agents
self.focus_messages = focus_messages
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
# self.build_agent()
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_agent(self,
embeddings: Embeddings = None, llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
):
# A customized agent can be registered by overriding only start_action and end_action
# class ExtraAgent(BaseAgent):
# def start_action_step(self, message: Message) -> Message:
# pass
# def end_action_step(self, message: Message) -> Message:
# pass
# agent_module = importlib.import_module("coagent.connector.agents")
# setattr(agent_module, 'extraAgent', ExtraAgent)
# A customized prompt-assembly method can also be registered
# class CodeRetrievalPM(PromptManager):
# def handle_code_packages(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# temporary handling, since the two agents share the same manager
# vertices = previous_agent_message.customed_kargs.get("RelatedVerticesRetrivalRes", {}).get("vertices", [])
# return ", ".join([str(v) for v in vertices])
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
# instantiate the agent
agent_module = importlib.import_module("coagent.connector.agents")
baseAgent: BaseAgent = getattr(agent_module, self.agent_type)
role = Role(
role_type=self.agent_type, role_name=self.role_name,
agent_type=self.agent_type, role_prompt=self.role_prompt,
)
self.build_config(embeddings, llm)
self.agent = baseAgent(
role=role,
prompt_config = [PromptField(**config) for config in self.prompt_config],
prompt_manager_type=self.prompt_manager_type,
chat_turn=self.chat_turn,
focus_agents=self.focus_agents,
focus_message_keys=self.focus_messages,
llm_config=self.llm_config,
embed_config=self.embed_config,
doc_retrieval=doc_retrieval or self.doc_retrieval,
code_retrieval=code_retrieval or self.code_retrieval,
search_retrieval=search_retrieval or self.search_retrieval,
)
class ChainFlow:
def __init__(
self,
chain_name: str,
chain_index: int = 0,
agent_flows: List[AgentFlow] = [],
chat_turn: int = 5,
do_checker: bool = False,
embeddings: Embeddings = None,
llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
# chain_type: str = "BaseChain",
**kwargs
):
self.agent_flows = sorted(agent_flows, key=lambda x:x.agent_index)
self.chat_turn = chat_turn
self.do_checker = do_checker
self.chain_name = chain_name
self.chain_index = chain_index
self.chain_type = "BaseChain"
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
# self.build_chain()
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_chain(self,
embeddings: Embeddings = None, llm: BaseLLM = None,
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
):
# instantiate the chain
chain_module = importlib.import_module("coagent.connector.chains")
baseChain: BaseChain = getattr(chain_module, self.chain_type)
agent_names = [agent_flow.role_name for agent_flow in self.agent_flows]
chain_config = ChainConfig(chain_name=self.chain_name, agents=agent_names, do_checker=self.do_checker, chat_turn=self.chat_turn)
# instantiate the agents
self.build_config(embeddings, llm)
for agent_flow in self.agent_flows:
agent_flow.build_agent(embeddings, llm)
self.chain = baseChain(
chain_config,
[agent_flow.agent for agent_flow in self.agent_flows],
embed_config=self.embed_config,
llm_config=self.llm_config,
doc_retrieval=doc_retrieval or self.doc_retrieval,
code_retrieval=code_retrieval or self.code_retrieval,
search_retrieval=search_retrieval or self.search_retrieval,
)
class PhaseFlow:
def __init__(
self,
phase_name: str,
chain_flows: List[ChainFlow],
embeddings: Embeddings = None,
llm: BaseLLM = None,
tools: List[Tool] = [],
doc_retrieval: IMRertrieval = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
**kwargs
):
self.phase_name = phase_name
self.chain_flows = sorted(chain_flows, key=lambda x:x.chain_index)
self.phase_type = "BasePhase"
self.tools = tools
self.embeddings = embeddings
self.llm = llm
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
# self.build_config()
self.build_phase()
def __call__(self, params: dict) -> str:
# tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
# query_content = "帮我确认下127.0.0.1这个服务器的在10点是否存在异常请帮我判断一下"
try:
logger.info(f"params: {params}")
query_content = params.get("query") or params.get("input")
search_type = params.get("search_type")
query = Message(
role_name="human", role_type="user", tools=self.tools,
role_content=query_content, input_query=query_content, origin_query=query_content,
cb_search_type=search_type,
)
# phase.pre_print(query)
output_message, output_memory = self.phase.step(query)
output_content = "\n\n".join((output_memory.to_str_messages(return_all=True, content_key="parsed_output_list").split("\n\n")[1:])) or output_message.role_content
return output_content
except Exception as e:
logger.exception(e)
return f"Error {e}"
def build_config(self, embeddings: Embeddings = None, llm: BaseLLM = None):
self.llm_config = LLMConfig(model_name="test", llm=self.llm or llm)
self.embed_config = EmbedConfig(embed_model="test", langchain_embeddings=self.embeddings or embeddings)
def build_phase(self, embeddings: Embeddings = None, llm: BaseLLM = None):
# instantiate the phase
phase_module = importlib.import_module("coagent.connector.phase")
basePhase: BasePhase = getattr(phase_module, self.phase_type)
# instantiate the chains
self.build_config(self.embeddings or embeddings, self.llm or llm)
os.environ["log_verbose"] = "2"
for chain_flow in self.chain_flows:
chain_flow.build_chain(
self.embeddings or embeddings, self.llm or llm,
self.doc_retrieval, self.code_retrieval, self.search_retrieval
)
self.phase: BasePhase = basePhase(
phase_name=self.phase_name,
chains=[chain_flow.chain for chain_flow in self.chain_flows],
embed_config=self.embed_config,
llm_config=self.llm_config,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval
)
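A hedged end-to-end sketch of wiring the new flow classes together (the role prompt, model, and embeddings are placeholders, not part of this commit; PhaseFlow builds its phase in the constructor and is directly callable):

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings

agent = AgentFlow(role_name="demo_coder", agent_type="BaseAgent", agent_index=0,
                  role_prompt="#### Agent Profile\nYou are a helpful coder.",  # illustrative prompt
                  prompt_config=[])  # illustrative; real configs come from the prompt_config presets
chain = ChainFlow(chain_name="demo_chain", chain_index=0, agent_flows=[agent])
phase = PhaseFlow(phase_name="demo_phase", chain_flows=[chain],
                  llm=ChatOpenAI(temperature=0.3), embeddings=HuggingFaceEmbeddings())
print(phase({"query": "write a quicksort function in python"}))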

View File

@ -1,9 +1,10 @@
from typing import List
from typing import List, Tuple, Union
from loguru import logger
import copy, os
from coagent.connector.agents import BaseAgent
from langchain.schema import BaseRetriever
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import (
Memory, Role, Message, ActionStatus, ChainConfig,
load_role_configs
@ -11,31 +12,32 @@ from coagent.connector.schema import (
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
from coagent.connector.configs.agent_config import AGETN_CONFIGS
role_configs = load_role_configs(AGETN_CONFIGS)
# from configs.model_config import JUPYTER_WORK_PATH
# from configs.server_config import SANDBOX_SERVER
class BaseChain:
def __init__(
self,
# chainConfig: ChainConfig,
chainConfig: ChainConfig,
agents: List[BaseAgent],
chat_turn: int = 1,
do_checker: bool = False,
# chat_turn: int = 1,
# do_checker: bool = False,
sandbox_server: dict = {},
jupyter_work_path: str = "",
kb_root_path: str = "",
jupyter_work_path: str = JUPYTER_WORK_PATH,
kb_root_path: str = KB_ROOT_PATH,
llm_config: LLMConfig = LLMConfig(),
embed_config: EmbedConfig = None,
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
) -> None:
# self.chainConfig = chainConfig
self.chainConfig = chainConfig
self.agents: List[BaseAgent] = agents
self.chat_turn = chat_turn
self.do_checker = do_checker
self.chat_turn = chainConfig.chat_turn
self.do_checker = chainConfig.do_checker
self.sandbox_server = sandbox_server
self.jupyter_work_path = jupyter_work_path
self.llm_config = llm_config
@ -45,9 +47,11 @@ class BaseChain:
task = None, memory = None,
llm_config=llm_config, embed_config=embed_config,
sandbox_server=sandbox_server, jupyter_work_path=jupyter_work_path,
kb_root_path=kb_root_path
kb_root_path=kb_root_path,
doc_retrieval=doc_retrieval, code_retrieval=code_retrieval,
search_retrieval=search_retrieval
)
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
self.messageUtils = MessageUtils(None, sandbox_server, self.jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
# all memory created by agent until instance deleted
self.global_memory = Memory(messages=[])
@ -62,13 +66,16 @@ class BaseChain:
for agent in self.agents:
agent.pre_print(query, history, background=background, memory_manager=memory_manager)
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Message:
def astep(self, query: Message, history: Memory = None, background: Memory = None, memory_manager: BaseMemoryManager = None) -> Tuple[Message, Memory]:
'''execute chain'''
local_memory = Memory(messages=[])
input_message = copy.deepcopy(query)
step_nums = copy.deepcopy(self.chat_turn)
check_message = None
# if input_message not in memory_manager:
# memory_manager.append(input_message)
self.global_memory.append(input_message)
# local_memory.append(input_message)
while step_nums > 0:
@@ -78,7 +85,7 @@ class BaseChain:
yield output_message, local_memory + output_message
output_message = self.messageUtils.inherit_extrainfo(input_message, output_message)
# choose an action (code_content or tool_content) according to the output
output_message = self.messageUtils.parser(output_message)
# output_message = self.messageUtils.parser(output_message)
yield output_message, local_memory + output_message
# output_message = self.step_router(output_message)
input_message = output_message

View File

@@ -1,9 +1,10 @@
from .agent_config import AGETN_CONFIGS
from .chain_config import CHAIN_CONFIGS
from .phase_config import PHASE_CONFIGS
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
from .prompt_config import *
__all__ = [
"AGETN_CONFIGS", "CHAIN_CONFIGS", "PHASE_CONFIGS",
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS"
"BASE_PROMPT_CONFIGS", "EXECUTOR_PROMPT_CONFIGS", "SELECTOR_PROMPT_CONFIGS", "BASE_NOTOOLPROMPT_CONFIGS",
"CODE2DOC_GROUP_PROMPT_CONFIGS", "CODE2DOC_PROMPT_CONFIGS", "CODE2TESTS_PROMPT_CONFIGS"
]

View File

@@ -1,19 +1,21 @@
from enum import Enum
from .prompts import (
REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
RECOGNIZE_INTENTION_PROMPT,
CHECKER_TEMPLATE_PROMPT,
CONV_SUMMARY_PROMPT,
QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
EXECUTOR_TEMPLATE_PROMPT,
REFINE_TEMPLATE_PROMPT,
SELECTOR_AGENT_TEMPLATE_PROMPT,
PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
REACT_TEMPLATE_PROMPT,
REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
)
from .prompt_config import BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
from .prompts import *
# from .prompts import (
# REACT_PROMPT_INPUT, CHECK_PROMPT_INPUT, EXECUTOR_PROMPT_INPUT, CONTEXT_PROMPT_INPUT, QUERY_CONTEXT_PROMPT_INPUT,PLAN_PROMPT_INPUT,
# RECOGNIZE_INTENTION_PROMPT,
# CHECKER_TEMPLATE_PROMPT,
# CONV_SUMMARY_PROMPT,
# QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT,
# EXECUTOR_TEMPLATE_PROMPT,
# REFINE_TEMPLATE_PROMPT,
# SELECTOR_AGENT_TEMPLATE_PROMPT,
# PLANNER_TEMPLATE_PROMPT, GENERAL_PLANNER_PROMPT, DATA_PLANNER_PROMPT, TOOL_PLANNER_PROMPT,
# PRD_WRITER_METAGPT_PROMPT, DESIGN_WRITER_METAGPT_PROMPT, TASK_WRITER_METAGPT_PROMPT, CODE_WRITER_METAGPT_PROMPT,
# REACT_TEMPLATE_PROMPT,
# REACT_TOOL_PROMPT, REACT_CODE_PROMPT, REACT_TOOL_AND_CODE_PLANNER_PROMPT, REACT_TOOL_AND_CODE_PROMPT
# )
from .prompt_config import *
# BASE_PROMPT_CONFIGS, EXECUTOR_PROMPT_CONFIGS, SELECTOR_PROMPT_CONFIGS, BASE_NOTOOLPROMPT_CONFIGS
@@ -261,4 +263,68 @@ AGETN_CONFIGS = {
"focus_agents": ["metaGPT_DESIGN", "metaGPT_TASK"],
"focus_message_keys": [],
},
"class2Docer": {
"role": {
"role_prompt": Class2Doc_PROMPT,
"role_type": "assistant",
"role_name": "class2Docer",
"role_desc": "",
"agent_type": "CodeGenDocer"
},
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
"prompt_manager_type": "Code2DocPM",
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
"func2Docer": {
"role": {
"role_prompt": Func2Doc_PROMPT,
"role_type": "assistant",
"role_name": "func2Docer",
"role_desc": "",
"agent_type": "CodeGenDocer"
},
"prompt_config": CODE2DOC_PROMPT_CONFIGS,
"prompt_manager_type": "Code2DocPM",
"chat_turn": 1,
"focus_agents": [],
"focus_message_keys": [],
},
"code2DocsGrouper": {
"role": {
"role_prompt": Code2DocGroup_PROMPT,
"role_type": "assistant",
"role_name": "code2DocsGrouper",
"role_desc": "",
"agent_type": "SelectorAgent"
},
"prompt_config": CODE2DOC_GROUP_PROMPT_CONFIGS,
"group_agents": ["class2Docer", "func2Docer"],
"chat_turn": 1,
},
"Code2TestJudger": {
"role": {
"role_prompt": judgeCode2Tests_PROMPT,
"role_type": "assistant",
"role_name": "Code2TestJudger",
"role_desc": "",
"agent_type": "CodeRetrieval"
},
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
"prompt_manager_type": "CodeRetrievalPM",
"chat_turn": 1,
},
"code2Tests": {
"role": {
"role_prompt": code2Tests_PROMPT,
"role_type": "assistant",
"role_name": "code2Tests",
"role_desc": "",
"agent_type": "CodeRetrieval"
},
"prompt_config": CODE2TESTS_PROMPT_CONFIGS,
"prompt_manager_type": "CodeRetrievalPM",
"chat_turn": 1,
},
}
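For reference, these entries are materialized through `load_role_configs`; a small sketch of resolving one of the new roles, assuming the attribute layout used by `init_chains` later in this commit:

```python
# sketch: resolving one of the new agent configs
from coagent.connector.schema import load_role_configs
from coagent.connector.configs import AGETN_CONFIGS

role_configs = load_role_configs(AGETN_CONFIGS)
cfg = role_configs["class2Docer"]
print(cfg.role.role_name, cfg.prompt_manager_type)  # class2Docer Code2DocPM
```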

View File

@@ -123,5 +123,21 @@ CHAIN_CONFIGS = {
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
},
"code2DocsGroupChain": {
"chain_name": "code2DocsGroupChain",
"chain_type": "BaseChain",
"agents": ["code2DocsGrouper"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
},
"code2TestsChain": {
"chain_name": "code2TestsChain",
"chain_type": "BaseChain",
"agents": ["Code2TestJudger", "code2Tests"],
"chat_turn": 1,
"do_checker": False,
"chain_prompt": ""
}
}

View File

@@ -14,44 +14,24 @@ PHASE_CONFIGS = {
"phase_name": "docChatPhase",
"phase_type": "BasePhase",
"chains": ["docChatChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": True,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"searchChatPhase": {
"phase_name": "searchChatPhase",
"phase_type": "BasePhase",
"chains": ["searchChatChain"],
"do_summary": False,
"do_search": True,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"codeChatPhase": {
"phase_name": "codeChatPhase",
"phase_type": "BasePhase",
"chains": ["codeChatChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": True,
"do_tool_retrieval": False,
"do_using_tool": False
},
"toolReactPhase": {
"phase_name": "toolReactPhase",
"phase_type": "BasePhase",
"chains": ["toolReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": True
},
"codeReactPhase": {
@@ -59,55 +39,36 @@ PHASE_CONFIGS = {
"phase_type": "BasePhase",
# "chains": ["codePlannerChain", "codeReactChain"],
"chains": ["planChain", "codeReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"codeToolReactPhase": {
"phase_name": "codeToolReactPhase",
"phase_type": "BasePhase",
"chains": ["codeToolPlanChain", "codeToolReactChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": True
},
"baseTaskPhase": {
"phase_name": "baseTaskPhase",
"phase_type": "BasePhase",
"chains": ["planChain", "executorChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"metagpt_code_devlop": {
"phase_name": "metagpt_code_devlop",
"phase_type": "BasePhase",
"chains": ["metagptChain",],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"baseGroupPhase": {
"phase_name": "baseGroupPhase",
"phase_type": "BasePhase",
"chains": ["baseGroupChain"],
"do_summary": False,
"do_search": False,
"do_doc_retrieval": False,
"do_code_retrieval": False,
"do_tool_retrieval": False,
"do_using_tool": False
},
"code2DocsGroup": {
"phase_name": "code2DocsGroup",
"phase_type": "BasePhase",
"chains": ["code2DocsGroupChain"],
},
"code2Tests": {
"phase_name": "code2Tests",
"phase_type": "BasePhase",
"chains": ["code2TestsChain"],
}
}
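Note that the per-phase `do_*` switches are gone: in the reworked `BasePhase` below they are derived from whichever retrieval objects get injected. A minimal sketch of running one of the new phases, assuming default configs suffice for a dry run:

```python
# sketch: phases are looked up by name; do_* flags now follow the injected retrievers
from coagent.connector.phase import BasePhase
from coagent.llm_models.llm_config import LLMConfig

phase = BasePhase(
    phase_name="code2DocsGroup",
    llm_config=LLMConfig(),  # assumption: defaults are enough here
    embed_config=None,       # no embeddings and no injected doc retriever
)
```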

View File

@@ -40,4 +40,41 @@ SELECTOR_PROMPT_CONFIGS = [
{"field_name": 'current_plan', "function_name": 'handle_current_plan'},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
]
CODE2DOC_GROUP_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
{"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
CODE2DOC_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
{"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
CODE2TESTS_PROMPT_CONFIGS = [
{"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
{"field_name": 'context_placeholder', "function_name": '', "is_context": True},
{"field_name": 'session_records', "function_name": 'handle_session_records'},
{"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
{"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
{"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
{"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
]
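Each entry binds one prompt section to a handler method on the `PromptManager` added later in this commit; `register_fields_from_config` resolves `function_name` via `getattr` and falls back to `handle_custom_data`. Roughly, for the `code_snippet` entry above (import path assumed from the new package's `__init__`):

```python
# rough equivalent of what register_fields_from_config does for one entry
from coagent.connector.prompt_manager import PromptManager

pm = PromptManager(role_prompt="", prompt_config=None)
pm.register_field(
    "code_snippet",
    function=getattr(pm, "handle_code_snippet", pm.handle_custom_data),
    is_context=True, omit_if_empty=True,
)
```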

View File

@@ -14,7 +14,8 @@ from .qa_template_prompt import QA_PROMPT, CODE_QA_PROMPT, QA_TEMPLATE_PROMPT, C
from .executor_template_prompt import EXECUTOR_TEMPLATE_PROMPT
from .refine_template_prompt import REFINE_TEMPLATE_PROMPT
from .code2doc_template_prompt import Code2DocGroup_PROMPT, Class2Doc_PROMPT, Func2Doc_PROMPT
from .code2test_template_prompt import code2Tests_PROMPT, judgeCode2Tests_PROMPT
from .agent_selector_template_prompt import SELECTOR_AGENT_TEMPLATE_PROMPT
from .react_template_prompt import REACT_TEMPLATE_PROMPT
@@ -37,5 +38,7 @@ __all__ = [
"SELECTOR_AGENT_TEMPLATE_PROMPT",
"PLANNER_TEMPLATE_PROMPT", "GENERAL_PLANNER_PROMPT", "DATA_PLANNER_PROMPT", "TOOL_PLANNER_PROMPT",
"REACT_TEMPLATE_PROMPT",
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT"
"REACT_CODE_PROMPT", "REACT_TOOL_PROMPT", "REACT_TOOL_AND_CODE_PROMPT", "REACT_TOOL_AND_CODE_PLANNER_PROMPT",
"Code2DocGroup_PROMPT", "Class2Doc_PROMPT", "Func2Doc_PROMPT",
"code2Tests_PROMPT", "judgeCode2Tests_PROMPT"
]

View File

@@ -0,0 +1,95 @@
Code2DocGroup_PROMPT = """#### Agent Profile
Your goal is to respond, based on the Context Data's information, with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
ATTENTION: respond carefully, following the "Response Output Format".
#### Input Format
#### Response Output Format
**Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
**Role:** Select the role from agent names
"""
Class2Doc_PROMPT = """#### Agent Profile
As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
Your specific task is to parse the given code snippet of a class and extract information about its instantiation parameters.
ATTENTION: respond carefully, following the "Response Output Format".
#### Input Format
**Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
#### Response Output Format
**Class Base:** Specify the base class or interface from which the current class extends, if any.
**Class Description:** Offer a brief description of the class's purpose and functionality.
**Init Parameters:** List each parameter of the constructor. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.
```json
[
{
"param": "parameter_name",
"param_description": "A brief description of what this parameter is used for.",
"param_type": "The data type of the parameter"
},
...
]
```
If the constructor takes no parameters, return
```json
[]
```
"""
Func2Doc_PROMPT = """#### Agent Profile
You are a high-level code documentation assistant, skilled at distilling function/method code into detailed and well-structured documentation.
ATTENTION: respond carefully, following the "Response Output Format".
#### Input Format
**Code Path:** Provide the code path of the function or method you wish to document.
This name will be used to identify and extract the relevant details from the code snippet provided.
**Code Snippet:** A segment of code that contains the function or method to be documented.
#### Response Output Format
**Class Description:** Offer a brief description of the method's (function's) purpose and functionality.
**Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
- `param`: The parameter name
- `param_description`: A concise explanation of the parameter's purpose.
- `param_type`: The data type of the parameter, if explicitly defined.
```json
[
{
"param": "parameter_name",
"param_description": "A brief description of what this parameter is used for.",
"param_type": "The data type of the parameter"
},
...
]
```
If the function/method takes no parameters, return
```json
[]
```
**Return Value Description:** Describe what the function/method returns upon completion.
**Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
"""

View File

@@ -0,0 +1,65 @@
judgeCode2Tests_PROMPT = """#### Agent Profile
When determining the necessity of writing test cases for a given code snippet,
it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
in addition to considering these critical factors:
1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
it's more likely to harbor bugs, and thus test cases should be written.
If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
#### Input Format
**Code Snippet:** the initial Code or objective that the user wanted to achieve
**Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
#### Response Output Format
**Action Status:** Set to 'finished' or 'continued'.
If set to 'finished', the code snippet does not warrant the generation of a test case.
If set to 'continued', the code snippet necessitates the creation of a test case.
**REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
"""
code2Tests_PROMPT = """#### Agent Profile
As an agent specializing in software quality assurance,
your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
ATTENTION: respond carefully, following the "Response Output Format".
Each test case should include:
1. A clear description of the test purpose.
2. The input values or conditions for the test.
3. The expected outcome or assertion for the test.
4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
5. The test code, including the necessary package and import statements.
#### Input Format
**Code Snippet:** the initial Code or objective that the user wanted to achieve
**Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
#### Response Output Format
**SaveFileName:** Construct a local file name based on the Question and Context, such as
```java
package/class.java
```
**Test Code:** Generate the test code for the current Code Snippet.
```java
...
```
"""

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod, ABC
from typing import List
from typing import List, Dict
import os, sys, copy, json
from jieba.analyse import extract_tags
from collections import Counter
@@ -10,12 +10,13 @@ from langchain.docstore.document import Document
from .schema import Memory, Message
from coagent.service.service_factory import KBServiceFactory
from coagent.llm_models import getChatModel, getChatModelFromConfig
from coagent.llm_models import getChatModelFromConfig
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.embeddings.utils import load_embeddings_from_path
from coagent.utils.common_utils import save_to_json_file, read_json_file, addMinutesToTime
from coagent.connector.configs.prompts import CONV_SUMMARY_PROMPT_SPEC
from coagent.orm import table_init
from coagent.base_configs.env_config import KB_ROOT_PATH
# from configs.model_config import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, SCORE_THRESHOLD
# from configs.model_config import embedding_model_dict
@@ -70,16 +71,22 @@ class BaseMemoryManager(ABC):
self.unique_name = unique_name
self.memory_type = memory_type
self.do_init = do_init
self.current_memory = Memory(messages=[])
self.recall_memory = Memory(messages=[])
self.summary_memory = Memory(messages=[])
# self.current_memory = Memory(messages=[])
# self.recall_memory = Memory(messages=[])
# self.summary_memory = Memory(messages=[])
self.current_memory_dict: Dict[str, Memory] = {}
self.recall_memory_dict: Dict[str, Memory] = {}
self.summary_memory_dict: Dict[str, Memory] = {}
self.save_message_keys = [
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
self.init_vb()
def init_vb(self):
def re_init(self, do_init: bool=False):
self.init_vb()
def init_vb(self, do_init: bool=None):
"""
Initializes the vb.
"""
@@ -135,13 +142,15 @@ class BaseMemoryManager(ABC):
"""
pass
def save_to_vs(self, embed_model="", embed_device=""):
def save_to_vs(self, ):
"""
Saves the memory to the vector space.
"""
pass
Args:
- embed_model: A string representing the embedding model. Default is EMBEDDING_MODEL.
- embed_device: A string representing the embedding device. Default is EMBEDDING_DEVICE.
def get_memory_pool(self, user_name: str, ):
"""
return memory_pool
"""
pass
@@ -230,7 +239,7 @@ class LocalMemoryManager(BaseMemoryManager):
unique_name: str = "default",
memory_type: str = "recall",
do_init: bool = False,
kb_root_path: str = "",
kb_root_path: str = KB_ROOT_PATH,
):
self.user_name = user_name
self.unique_name = unique_name
@@ -239,16 +248,22 @@ class LocalMemoryManager(BaseMemoryManager):
self.kb_root_path = kb_root_path
self.embed_config: EmbedConfig = embed_config
self.llm_config: LLMConfig = llm_config
self.current_memory = Memory(messages=[])
self.recall_memory = Memory(messages=[])
self.summary_memory = Memory(messages=[])
# self.current_memory = Memory(messages=[])
# self.recall_memory = Memory(messages=[])
# self.summary_memory = Memory(messages=[])
self.current_memory_dict: Dict[str, Memory] = {}
self.recall_memory_dict: Dict[str, Memory] = {}
self.summary_memory_dict: Dict[str, Memory] = {}
self.save_message_keys = [
'chat_index', 'role_name', 'role_type', 'role_prompt', 'input_query', 'origin_query',
'datetime', 'role_content', 'step_content', 'parsed_output', 'spec_parsed_output', 'parsed_output_list',
'task', 'db_docs', 'code_docs', 'search_docs', 'phase_name', 'chain_name', 'customed_kargs']
self.init_vb()
def init_vb(self):
def re_init(self, do_init: bool=False):
self.init_vb(do_init)
def init_vb(self, do_init: bool=None):
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
# default to recreate a new vb
table_init()
@@ -256,31 +271,37 @@ class LocalMemoryManager(BaseMemoryManager):
if vb:
status = vb.clear_vs()
if not self.do_init:
check_do_init = do_init if do_init else self.do_init
if not check_do_init:
self.load(self.kb_root_path)
self.save_to_vs()
def append(self, message: Message):
self.recall_memory.append(message)
self.check_user_name(message.user_name)
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
self.recall_memory_dict[uuid_name].append(message)
#
if message.role_type == "summary":
self.summary_memory.append(message)
self.summary_memory_dict[uuid_name].append(message)
else:
self.current_memory.append(message)
self.current_memory_dict[uuid_name].append(message)
self.save(self.kb_root_path)
self.save_new_to_vs([message])
def extend(self, memory: Memory):
self.recall_memory.extend(memory)
self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
self.save(self.kb_root_path)
self.save_new_to_vs(memory.messages)
# def extend(self, memory: Memory):
# self.recall_memory.extend(memory)
# self.current_memory.extend(self.recall_memory.filter_by_role_type(["summary"]))
# self.summary_memory.extend(self.recall_memory.select_by_role_type(["summary"]))
# self.save(self.kb_root_path)
# self.save_new_to_vs(memory.messages)
def save(self, save_dir: str = "./"):
file_path = os.path.join(save_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
memory_messages = self.recall_memory.dict()
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
memory_messages = self.recall_memory_dict[uuid_name].dict()
memory_messages = {k: [
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
for vv in v ]
@@ -291,18 +312,28 @@ class LocalMemoryManager(BaseMemoryManager):
def load(self, load_dir: str = "./") -> Memory:
file_path = os.path.join(load_dir, f"{self.user_name}/{self.unique_name}/{self.memory_type}/converation.jsonl")
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
if os.path.exists(file_path):
self.recall_memory = Memory(**read_json_file(file_path))
self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
# self.recall_memory = Memory(**read_json_file(file_path))
# self.current_memory = Memory(messages=self.recall_memory.filter_by_role_type(["summary"]))
# self.summary_memory = Memory(messages=self.recall_memory.select_by_role_type(["summary"]))
recall_memory = Memory(**read_json_file(file_path))
self.recall_memory_dict[uuid_name] = recall_memory
self.current_memory_dict[uuid_name] = Memory(messages=recall_memory.filter_by_role_type(["summary"]))
self.summary_memory_dict[uuid_name] = Memory(messages=recall_memory.select_by_role_type(["summary"]))
else:
self.recall_memory_dict[uuid_name] = Memory(messages=[])
self.current_memory_dict[uuid_name] = Memory(messages=[])
self.summary_memory_dict[uuid_name] = Memory(messages=[])
def save_new_to_vs(self, messages: List[Message]):
if self.embed_config:
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
# default to faiss, todo: add new vstype
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
messages = [
{k: v for k, v in m.dict().items() if k in self.save_message_keys}
for m in messages]
@@ -311,23 +342,26 @@ class LocalMemoryManager(BaseMemoryManager):
vb.do_add_doc(docs, embeddings)
def save_to_vs(self):
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
# default to recreate a new vb
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
if vb:
status = vb.clear_vs()
# create_kb(vb_name, "faiss", embed_model)
'''only after load'''
if self.embed_config:
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
# default to recreate a new vb
vb = KBServiceFactory.get_service_by_name(vb_name, self.embed_config, self.kb_root_path)
if vb:
status = vb.clear_vs()
# create_kb(vb_name, "faiss", embed_model)
# default to faiss, todo: add new vstype
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device,)
messages = self.recall_memory.dict()
messages = [
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
for k, v in messages.items() for vv in v]
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
docs = [Document(**doc) for doc in docs]
vb.do_add_doc(docs, embeddings)
# default to faiss, todo: add new vstype
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
embeddings = load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
messages = self.recall_memory_dict[uuid_name].dict()
messages = [
{kkk: vvv for kkk, vvv in vv.items() if kkk in self.save_message_keys}
for k, v in messages.items() for vv in v]
docs = [{"page_content": m["step_content"] or m["role_content"] or m["input_query"] or m["origin_query"], "metadata": m} for m in messages]
docs = [Document(**doc) for doc in docs]
vb.do_add_doc(docs, embeddings)
# def load_from_vs(self, embed_model=EMBEDDING_MODEL) -> Memory:
# vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
@@ -338,7 +372,12 @@ class LocalMemoryManager(BaseMemoryManager):
# docs = vb.get_all_documents()
# print(docs)
def router_retrieval(self, text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
def get_memory_pool(self, user_name: str, ):
self.check_user_name(user_name)
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
return self.recall_memory_dict[uuid_name]
def router_retrieval(self, user_name: str = "default", text: str=None, datetime: str = None, n=5, top_k=5, retrieval_type: str = "embedding", **kwargs) -> List[Message]:
retrieval_func_dict = {
"embedding": self.embedding_retrieval, "text": self.text_retrieval, "datetime": self.datetime_retrieval
}
@@ -356,20 +395,22 @@ class LocalMemoryManager(BaseMemoryManager):
#
return retrieval_func(**params)
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, **kwargs) -> List[Message]:
def embedding_retrieval(self, text: str, top_k=1, score_threshold=1.0, user_name: str = "default", **kwargs) -> List[Message]:
if text is None: return []
vb_name = f"{self.user_name}/{self.unique_name}/{self.memory_type}"
vb_name = f"{user_name}/{self.unique_name}/{self.memory_type}"
vb = KBServiceFactory.get_service(vb_name, "faiss", self.embed_config, self.kb_root_path)
docs = vb.search_docs(text, top_k=top_k, score_threshold=score_threshold)
return [Message(**doc.metadata) for doc, score in docs]
def text_retrieval(self, text: str, **kwargs) -> List[Message]:
def text_retrieval(self, text: str, user_name: str = "default", **kwargs) -> List[Message]:
if text is None: return []
return self._text_retrieval_from_cache(self.recall_memory.messages, text, score_threshold=0.3, topK=5, **kwargs)
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
return self._text_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, text, score_threshold=0.3, topK=5, **kwargs)
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, **kwargs) -> List[Message]:
def datetime_retrieval(self, datetime: str, text: str = None, n: int = 5, user_name: str = "default", **kwargs) -> List[Message]:
if datetime is None: return []
return self._datetime_retrieval_from_cache(self.recall_memory.messages, datetime, text, n, **kwargs)
uuid_name = "_".join([user_name, self.unique_name, self.memory_type])
return self._datetime_retrieval_from_cache(self.recall_memory_dict[uuid_name].messages, datetime, text, n, **kwargs)
def _text_retrieval_from_cache(self, messages: List[Message], text: str = None, score_threshold=0.3, topK=5, tag_topK=5, **kwargs) -> List[Message]:
keywords = extract_tags(text, topK=tag_topK)
@@ -427,4 +468,18 @@ class LocalMemoryManager(BaseMemoryManager):
)
summary_message.parsed_output_list.append({"summary": content})
newest_messages.insert(0, summary_message)
return newest_messages
return newest_messages
def check_user_name(self, user_name: str):
# logger.debug(f"self.user_name is {self.user_name}")
if user_name != self.user_name:
self.user_name = user_name
self.init_vb()
uuid_name = "_".join([self.user_name, self.unique_name, self.memory_type])
if uuid_name not in self.recall_memory_dict:
self.recall_memory_dict[uuid_name] = Memory(messages=[])
self.current_memory_dict[uuid_name] = Memory(messages=[])
self.summary_memory_dict[uuid_name] = Memory(messages=[])
# logger.debug(f"self.user_name is {self.user_name}")
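Taken together, the manager now partitions memories per user under a `"<user>_<unique>_<memory_type>"` key. A sketch of the intended flow, with keyword arguments as in the signature above (`msg` stands for any `Message` carrying a `user_name`):

```python
# sketch: per-user memory pools; embed_config=None skips vector persistence
from coagent.connector.memory_manager import LocalMemoryManager
from coagent.llm_models.llm_config import LLMConfig

mm = LocalMemoryManager(
    embed_config=None, llm_config=LLMConfig(),
    user_name="alice", unique_name="default", memory_type="recall",
)
mm.append(msg)                      # msg.user_name switches the active pool
pool = mm.get_memory_pool("alice")  # Memory holding alice's recall messages
```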

View File

@@ -1,16 +1,19 @@
import re, traceback, uuid, copy, json, os
from typing import Union
from loguru import logger
from langchain.schema import BaseRetriever
# from configs.server_config import SANDBOX_SERVER
# from configs.model_config import JUPYTER_WORK_PATH
from coagent.connector.schema import (
Memory, Role, Message, ActionStatus, CodeDoc, Doc, LogVerboseEnum
)
from coagent.retrieval.base_retrieval import IMRertrieval
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.tools import DDGSTool, DocRetrieval, CodeRetrieval
from coagent.sandbox import PyCodeBox, CodeBoxResponse
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH
from .utils import parse_dict_to_dict, parse_text_to_dict
@@ -19,10 +22,13 @@ class MessageUtils:
self,
role: Role = None,
sandbox_server: dict = {},
jupyter_work_path: str = "./",
jupyter_work_path: str = JUPYTER_WORK_PATH,
embed_config: EmbedConfig = None,
llm_config: LLMConfig = None,
kb_root_path: str = "",
doc_retrieval: Union[BaseRetriever, IMRertrieval] = None,
code_retrieval: IMRertrieval = None,
search_retrieval: IMRertrieval = None,
log_verbose: str = "0"
) -> None:
self.role = role
@@ -31,6 +37,9 @@ class MessageUtils:
self.embed_config = embed_config
self.llm_config = llm_config
self.kb_root_path = kb_root_path
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
self.codebox = PyCodeBox(
remote_url=self.sandbox_server.get("url", "http://127.0.0.1:5050"),
remote_ip=self.sandbox_server.get("host", "http://127.0.0.1"),
@@ -44,6 +53,7 @@ class MessageUtils:
self.log_verbose = os.environ.get("log_verbose", "0") or log_verbose
def inherit_extrainfo(self, input_message: Message, output_message: Message):
output_message.user_name = input_message.user_name
output_message.db_docs = input_message.db_docs
output_message.search_docs = input_message.search_docs
output_message.code_docs = input_message.code_docs
@@ -116,18 +126,45 @@ class MessageUtils:
knowledge_basename = message.doc_engine_name
top_k = message.top_k
score_threshold = message.score_threshold
if knowledge_basename:
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
if self.doc_retrieval:
if isinstance(self.doc_retrieval, BaseRetriever):
docs = self.doc_retrieval.get_relevant_documents(query)
else:
# docs = self.doc_retrieval.run(query, search_top=message.top_k, score_threshold=message.score_threshold,)
docs = self.doc_retrieval.run(query)
docs = [
{"index": idx, "snippet": doc.page_content, "title": doc.metadata.get("title_prefix", ""), "link": doc.metadata.get("url", "")}
for idx, doc in enumerate(docs)
]
message.db_docs = [Doc(**doc) for doc in docs]
else:
if knowledge_basename:
docs = DocRetrieval.run(query, knowledge_basename, top_k, score_threshold, self.embed_config, self.kb_root_path)
message.db_docs = [Doc(**doc) for doc in docs]
return message
def get_code_retrieval(self, message: Message) -> Message:
query = message.input_query
query = message.role_content
code_engine_name = message.code_engine_name
history_node_list = message.history_node_list
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
llm_config=self.llm_config, embed_config=self.embed_config,)
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
use_nh = message.use_nh
local_graph_path = message.local_graph_path
if self.code_retrieval:
code_docs = self.code_retrieval.run(
query, history_node_list=history_node_list, search_type=message.cb_search_type,
code_limit=1
)
else:
code_docs = CodeRetrieval.run(code_engine_name, query, code_limit=message.top_k, history_node_list=history_node_list, search_type=message.cb_search_type,
llm_config=self.llm_config, embed_config=self.embed_config,
use_nh=use_nh, local_graph_path=local_graph_path)
message.code_docs = [CodeDoc(**doc) for doc in code_docs]
# related_nodes = [doc.get_related_node() for idx, doc in enumerate(message.code_docs) if idx==0],
# history_node_list.extend([node[0] for node in related_nodes])
return message
def get_tool_retrieval(self, message: Message) -> Message:
@@ -160,6 +197,7 @@ class MessageUtils:
if code_answer.code_exe_type == "error" else f"The return information after executing the above code is {code_answer.code_exe_response}.\n"
observation_message = Message(
user_name=message.user_name,
role_name="observation",
role_type="function", #self.role.role_type,
role_content="",
@@ -190,6 +228,7 @@ class MessageUtils:
def tool_step(self, message: Message) -> Message:
'''execute tool'''
observation_message = Message(
user_name=message.user_name,
role_name="observation",
role_type="function", #self.role.role_type,
role_content="\n**Observation:** there is no tool can execute\n",
@@ -226,7 +265,7 @@ class MessageUtils:
return message, observation_message
def parser(self, message: Message) -> Message:
''''''
'''parse llm output into dict'''
content = message.role_content
# parse start
parsed_dict = parse_text_to_dict(content)

View File

@@ -5,6 +5,8 @@ import importlib
import copy
from loguru import logger
from langchain.schema import BaseRetriever
from coagent.connector.agents import BaseAgent
from coagent.connector.chains import BaseChain
from coagent.connector.schema import (
@@ -18,9 +20,6 @@ from coagent.connector.message_process import MessageUtils
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.base_configs.env_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.model_config import JUPYTER_WORK_PATH, KB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
role_configs = load_role_configs(AGETN_CONFIGS)
chain_configs = load_chain_configs(CHAIN_CONFIGS)
@@ -39,20 +38,24 @@ class BasePhase:
kb_root_path: str = KB_ROOT_PATH,
jupyter_work_path: str = JUPYTER_WORK_PATH,
sandbox_server: dict = {},
embed_config: EmbedConfig = EmbedConfig(),
llm_config: LLMConfig = LLMConfig(),
embed_config: EmbedConfig = None,
llm_config: LLMConfig = None,
task: Task = None,
base_phase_config: Union[dict, str] = PHASE_CONFIGS,
base_chain_config: Union[dict, str] = CHAIN_CONFIGS,
base_role_config: Union[dict, str] = AGETN_CONFIGS,
chains: List[BaseChain] = [],
doc_retrieval: Union[BaseRetriever] = None,
code_retrieval = None,
search_retrieval = None,
log_verbose: str = "0"
) -> None:
#
self.phase_name = phase_name
self.do_summary = False
self.do_search = False
self.do_code_retrieval = False
self.do_doc_retrieval = False
self.do_search = search_retrieval is not None
self.do_code_retrieval = code_retrieval is not None
self.do_doc_retrieval = doc_retrieval is not None
self.do_tool_retrieval = False
# memory_pool doesn't keep a specific order
# self.memory_pool = Memory(messages=[])
@@ -62,12 +65,15 @@ class BasePhase:
self.jupyter_work_path = jupyter_work_path
self.kb_root_path = kb_root_path
self.log_verbose = max(os.environ.get("log_verbose", "0"), log_verbose)
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, log_verbose)
# TODO: pass these retrievals through
self.doc_retrieval = doc_retrieval
self.code_retrieval = code_retrieval
self.search_retrieval = search_retrieval
self.message_utils = MessageUtils(None, sandbox_server, jupyter_work_path, embed_config, llm_config, kb_root_path, doc_retrieval, code_retrieval, search_retrieval, log_verbose)
self.global_memory = Memory(messages=[])
self.phase_memory: List[Memory] = []
# init this phase's chains according to the phase name
self.chains: List[BaseChain] = self.init_chains(
self.chains: List[BaseChain] = chains if chains else self.init_chains(
phase_name,
phase_config,
task=task,
@@ -90,7 +96,9 @@ class BasePhase:
kb_root_path=kb_root_path
)
def astep(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
def astep(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
if reinit_memory:
self.memory_manager.re_init(reinit_memory)
self.memory_manager.append(query)
summary_message = None
chain_message = Memory(messages=[])
@@ -139,8 +147,8 @@ class BasePhase:
message.role_name = self.phase_name
yield message, local_phase_memory
def step(self, query: Message, history: Memory = None) -> Tuple[Message, Memory]:
for message, local_phase_memory in self.astep(query, history=history):
def step(self, query: Message, history: Memory = None, reinit_memory=False) -> Tuple[Message, Memory]:
for message, local_phase_memory in self.astep(query, history=history, reinit_memory=reinit_memory):
pass
return message, local_phase_memory
@@ -194,6 +202,9 @@ class BasePhase:
sandbox_server=self.sandbox_server,
jupyter_work_path=self.jupyter_work_path,
kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose
)
if agent_config.role.agent_type == "SelectorAgent":
@@ -205,7 +216,7 @@ class BasePhase:
group_base_agent = baseAgent(
role=group_agent_config.role,
prompt_config = group_agent_config.prompt_config,
prompt_manager_type=agent_config.prompt_manager_type,
prompt_manager_type=group_agent_config.prompt_manager_type,
task = task,
memory = memory,
chat_turn=group_agent_config.chat_turn,
@@ -216,6 +227,9 @@ class BasePhase:
sandbox_server=self.sandbox_server,
jupyter_work_path=self.jupyter_work_path,
kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose
)
base_agent.group_agents.append(group_base_agent)
@@ -223,13 +237,16 @@ class BasePhase:
agents.append(base_agent)
chain_instance = BaseChain(
agents, chain_config.chat_turn,
do_checker=chain_configs[chain_name].do_checker,
chain_config,
agents,
jupyter_work_path=self.jupyter_work_path,
sandbox_server=self.sandbox_server,
embed_config=self.embed_config,
llm_config=self.llm_config,
kb_root_path=self.kb_root_path,
doc_retrieval=self.doc_retrieval,
code_retrieval=self.code_retrieval,
search_retrieval=self.search_retrieval,
log_verbose=self.log_verbose
)
chains.append(chain_instance)

View File

@@ -0,0 +1,2 @@
from .prompt_manager import PromptManager
from .extend_manager import *

View File

@@ -0,0 +1,45 @@
from coagent.connector.schema import Message
from .prompt_manager import PromptManager
class Code2DocPM(PromptManager):
def handle_code_snippet(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
instruction = "A segment of code that contains the function or method to be documented.\n"
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
def handle_specific_objective(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
specific_objective = previous_agent_message.parsed_output.get("Code Path")
instruction = "Provide the code path of the function or method you wish to document.\n"
s = instruction + f"\n{specific_objective}"
return s
class CodeRetrievalPM(PromptManager):
def handle_code_snippet(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
instruction = "the initial Code or objective that the user wanted to achieve"
return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
def handle_retrieval_codes(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs['previous_agent_message']
Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
instruction = "the initial Code or objective that the user wanted to achieve"
s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
return s
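These handlers read from `customed_kargs`, so the upstream retrieval step is expected to stash the graph-derived fields on the message first; for example (the values here are hypothetical):

```python
# sketch: fields the extend managers expect on an incoming Message
msg.customed_kargs["Code Snippet"] = java_source            # code under analysis
msg.customed_kargs["Current_Vertex"] = "com.demo.Foo#bar"   # hypothetical vertex id
msg.customed_kargs["Retrieval_Codes"] = retrieved_sources   # list aligned with...
msg.customed_kargs["Relative_vertex"] = related_vertices    # ...these vertex names
```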

View File

@@ -0,0 +1,353 @@
import random
from textwrap import dedent
import copy
from loguru import logger
from langchain.agents.tools import Tool
from coagent.connector.schema import Memory, Message
from coagent.connector.utils import extract_section, parse_section
class PromptManager:
def __init__(self, role_prompt="", prompt_config=None, monitored_agents=[], monitored_fields=[]):
self.role_prompt = role_prompt
self.monitored_agents = monitored_agents
self.monitored_fields = monitored_fields
self.field_handlers = {}
self.context_handlers = {}
self.field_order = [] # order of the ordinary fields
self.context_order = [] # order of the context fields, maintained separately
self.field_descriptions = {}
self.omit_if_empty_flags = {}
self.context_title = "### Context Data\n\n"
self.prompt_config = prompt_config
if self.prompt_config:
self.register_fields_from_config()
def register_field(self, field_name, function=None, title=None, description=None, is_context=True, omit_if_empty=True):
"""
Register a new field and its handler function.
Args:
field_name (str): the field name
function (callable): the function that processes this field's data
title (str, optional): a custom title for the field
description (str, optional): a description of the field; may be a few sentences
is_context (bool, optional): whether this field is a context field
omit_if_empty (bool, optional): whether to omit this field when its data is empty
"""
if not function:
function = self.handle_custom_data
# Register the handler function based on context flag
if is_context:
self.context_handlers[field_name] = function
else:
self.field_handlers[field_name] = function
# Store the custom title if provided and adjust the title prefix based on context
title_prefix = "####" if is_context else "###"
if title is not None:
self.field_descriptions[field_name] = f"{title_prefix} {title}\n\n"
elif description is not None:
# If title is not provided but description is, use description as title
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n{description}\n\n"
else:
# If neither title nor description is provided, use the field name as title
self.field_descriptions[field_name] = f"{title_prefix} {field_name.replace('_', ' ').title()}\n\n"
# Store the omit_if_empty flag for this field
self.omit_if_empty_flags[field_name] = omit_if_empty
if is_context and field_name != 'context_placeholder':
self.context_handlers[field_name] = function
self.context_order.append(field_name)
else:
self.field_handlers[field_name] = function
self.field_order.append(field_name)
def generate_full_prompt(self, **kwargs):
full_prompt = []
context_prompts = [] # collects the context sections
is_pre_print = kwargs.get("is_pre_print", False) # force printing every prompt field, whether empty or not
# handle the context fields first
for field_name in self.context_order:
handler = self.context_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"#### {field_name.replace('_', ' ').title()}\n\n")
context_prompts.append(title_or_description + processed_prompt + '\n\n')
# handle the ordinary fields, looking up the position of context_placeholder as we go
for field_name in self.field_order:
if field_name == 'context_placeholder':
# insert the context data at the context_placeholder position
full_prompt.append(self.context_title) # add the heading of the context section
full_prompt.extend(context_prompts) # add the collected context sections
else:
handler = self.field_handlers[field_name]
processed_prompt = handler(**kwargs)
# Check if the field should be omitted when empty
if self.omit_if_empty_flags.get(field_name, False) and not processed_prompt and not is_pre_print:
continue # Skip this field
title_or_description = self.field_descriptions.get(field_name, f"### {field_name.replace('_', ' ').title()}\n\n")
full_prompt.append(title_or_description + processed_prompt + '\n\n')
# return the full prompt, stripping trailing blank lines
return ''.join(full_prompt).rstrip('\n')
def pre_print(self, **kwargs):
kwargs.update({"is_pre_print": True})
prompt = self.generate_full_prompt(**kwargs)
input_keys = parse_section(self.role_prompt, 'Response Output Format')
llm_predict = "\n".join([f"**{k}:**" for k in input_keys])
return prompt + "\n\n" + "#"*19 + "\n<<<<LLM PREDICT>>>>\n" + "#"*19 + f"\n\n{llm_predict}\n"
def handle_custom_data(self, **kwargs):
return ""
def handle_tool_data(self, **kwargs):
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs.get('previous_agent_message')
tools: list[Tool] = previous_agent_message.tools
if not tools:
return ""
tool_strings = []
for tool in tools:
args_str = f'args: {str(tool.args)}' if tool.args_schema else ""
tool_strings.append(f"{tool.name}: {tool.description}, {args_str}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([tool.name for tool in tools])
tool_prompt = dedent(f"""
Below is a list of tools that are available for your use:
{formatted_tools}
valid "tool_name" value is:
{tool_names}
""")
return tool_prompt
def handle_agent_data(self, **kwargs):
if 'agents' not in kwargs:
return ""
agents = kwargs.get('agents')
random.shuffle(agents)
agent_names = ", ".join([f'{agent.role.role_name}' for agent in agents])
agent_descs = []
for agent in agents:
role_desc = agent.role.role_prompt.split("####")[1]
while "\n\n" in role_desc:
role_desc = role_desc.replace("\n\n", "\n")
role_desc = role_desc.replace("\n", ",")
agent_descs.append(f'"role name: {agent.role.role_name}\nrole description: {role_desc}"')
agents = "\n".join(agent_descs)
agent_prompt = f'''
Please ensure your selection is one of the listed roles. Available roles for selection:
{agents}
Please ensure select the Role from agent names, such as {agent_names}'''
return dedent(agent_prompt)
def handle_doc_info(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message: Message = kwargs.get('previous_agent_message')
db_docs = previous_agent_message.db_docs
search_docs = previous_agent_message.search_docs
code_docs = previous_agent_message.code_docs
doc_infos = "\n".join([doc.get_snippet() for doc in db_docs] + [doc.get_snippet() for doc in search_docs] +
[doc.get_code() for doc in code_docs])
return doc_infos
def handle_session_records(self, **kwargs) -> str:
memory_pool: Memory = kwargs.get('memory_pool', Memory(messages=[]))
memory_pool = self.select_memory_by_agent_name(memory_pool)
memory_pool = self.select_memory_by_parsed_key(memory_pool)
return memory_pool.to_str_messages(content_key="parsed_output_list", with_tag=True)
def handle_current_plan(self, **kwargs) -> str:
if 'previous_agent_message' not in kwargs:
return ""
previous_agent_message = kwargs['previous_agent_message']
return previous_agent_message.parsed_output.get("CURRENT_STEP", "")
def handle_agent_profile(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Agent Profile')
def handle_output_format(self, **kwargs) -> str:
return extract_section(self.role_prompt, 'Response Output Format')
def handle_response(self, **kwargs) -> str:
if 'react_memory' not in kwargs:
return ""
react_memory = kwargs.get('react_memory', Memory(messages=[]))
if react_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items()]) for _dict in react_memory.get_parserd_output()])
def handle_task_records(self, **kwargs) -> str:
if 'task_memory' not in kwargs:
return ""
task_memory: Memory = kwargs.get('task_memory', Memory(messages=[]))
if task_memory is None:
return ""
return "\n".join(["\n".join([f"**{k}:**\n{v}" for k,v in _dict.items() if k not in ["CURRENT_STEP"]]) for _dict in task_memory.get_parserd_output()])
def handle_previous_message(self, message: Message) -> str:
pass
def handle_message_by_role_name(self, message: Message) -> str:
pass
def handle_message_by_role_type(self, message: Message) -> str:
pass
def handle_current_agent_react_message(self, message: Message) -> str:
pass
def extract_codedoc_info_for_prompt(self, message: Message) -> str:
code_docs = message.code_docs
doc_infos = "\n".join([doc.get_code() for doc in code_docs])
return doc_infos
def select_memory_by_parsed_key(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_parsed_key(message) for message in memory.messages
if self.select_message_by_parsed_key(message) is not None]
)
def select_memory_by_agent_name(self, memory: Memory) -> Memory:
return Memory(
messages=[self.select_message_by_agent_name(message) for message in memory.messages
if self.select_message_by_agent_name(message) is not None]
)
def select_message_by_agent_name(self, message: Message) -> Message:
# an empty list means every agent is monitored
if self.monitored_agents == []:
return message
return None if message is None or message.role_name not in self.monitored_agents else self.select_message_by_parsed_key(message)
def select_message_by_parsed_key(self, message: Message) -> Message:
# an empty list means every parsed key is kept
if message is None:
return message
if self.monitored_fields == []:
return message
message_c = copy.deepcopy(message)
message_c.parsed_output = {k: v for k,v in message_c.parsed_output.items() if k in self.monitored_fields}
message_c.parsed_output_list = [{k: v for k,v in parsed_output.items() if k in self.monitored_fields} for parsed_output in message_c.parsed_output_list]
return message_c
def get_memory(self, content_key="role_content"):
return self.memory.to_tuple_messages(content_key="step_content")
def get_memory_str(self, content_key="role_content"):
return "\n".join([": ".join(i) for i in self.memory.to_tuple_messages(content_key="step_content")])
def register_fields_from_config(self):
for prompt_field in self.prompt_config:
function_name = prompt_field.function_name
# check whether function_name is a method on self
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
self.register_field(prompt_field.field_name,
function=function,
title=prompt_field.title,
description=prompt_field.description,
is_context=prompt_field.is_context,
omit_if_empty=prompt_field.omit_if_empty)
def register_standard_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_executor_fields(self):
self.register_field('agent_profile', function=self.handle_agent_profile, is_context=False)
self.register_field('tool_information', function=self.handle_tool_data, is_context=False)
self.register_field('context_placeholder', is_context=True) # marks where the context-data section goes
self.register_field('reference_documents', function=self.handle_doc_info, is_context=True)
self.register_field('session_records', function=self.handle_session_records, is_context=True)
self.register_field('current_plan', function=self.handle_current_plan, is_context=True)
self.register_field('output_format', function=self.handle_output_format, title='Response Output Format', is_context=False)
self.register_field('response', function=self.handle_response, is_context=False, omit_if_empty=False)
def register_fields_from_dict(self, fields_dict):
# register fields from a dict
for field_name, field_config in fields_dict.items():
function_name = field_config.get('function', None)
title = field_config.get('title', None)
description = field_config.get('description', None)
is_context = field_config.get('is_context', True)
omit_if_empty = field_config.get('omit_if_empty', True)
# check whether function_name is a method on self
if function_name and hasattr(self, function_name):
function = getattr(self, function_name)
else:
function = self.handle_custom_data
# register the field via the existing register_field method
self.register_field(field_name, function=function, title=title, description=description, is_context=is_context, omit_if_empty=omit_if_empty)
def main():
manager = PromptManager()
manager.register_standard_fields()
manager.register_field('agents_work_progress', title="Agents' Work Progress", is_context=True)
# build the data dict
data_dict = {
"agent_profile": "the agent profile goes here...",
# "tool_list": "the tool list goes here...",
"reference_documents": "reference documents go here...",
"session_records": "session records go here...",
"agents_work_progress": "agents' work progress goes here...",
"output_format": "the expected output format goes here...",
# "response": "instructions to generate or continue the response...",
"response": "",
"test": 'xxxxx'
}
# assemble the full prompt
full_prompt = manager.generate_full_prompt(data_dict)
print(full_prompt)
if __name__ == "__main__":
main()

View File

@ -215,15 +215,15 @@ class Env(BaseModel):
class Role(BaseModel):
role_type: str
role_name: str
role_desc: str
agent_type: str = ""
role_desc: str = ""
agent_type: str = "BaseAgent"
role_prompt: str = ""
template_prompt: str = ""
class ChainConfig(BaseModel):
chain_name: str
chain_type: str
chain_type: str = "BaseChain"
agents: List[str]
do_checker: bool = False
chat_turn: int = 1

View File

@ -131,6 +131,9 @@ class Memory(BaseModel):
# logger.debug(f"{message.role_name}: {message.parsed_output_list}")
# return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list[1:]]
return [parsed_output for message in self.messages for parsed_output in message.parsed_output_list]
def get_spec_parserd_output(self, ):
return [message.spec_parsed_output for message in self.messages]
def get_rolenames(self, ):
''''''

View File

@ -7,6 +7,7 @@ from .general_schema import *
class Message(BaseModel):
chat_index: str = None
user_name: str = "default"
role_name: str
role_type: str
role_prompt: str = None
@ -53,6 +54,8 @@ class Message(BaseModel):
cb_search_type: str = None
search_engine_name: str = None
top_k: int = 3
use_nh: bool = True
local_graph_path: str = ''
score_threshold: float = 1.0
do_doc_retrieval: bool = False
do_code_retrieval: bool = False

View File

@ -72,20 +72,25 @@ def parse_text_to_dict(text):
def parse_dict_to_dict(parsed_dict) -> dict:
code_pattern = r'```python\n(.*?)```'
tool_pattern = r'```json\n(.*?)```'
java_pattern = r'```java\n(.*?)```'
pattern_dict = {"code": code_pattern, "json": tool_pattern}
pattern_dict = {"code": code_pattern, "json": tool_pattern, "java": java_pattern}
spec_parsed_dict = copy.deepcopy(parsed_dict)
for key, pattern in pattern_dict.items():
for k, text in parsed_dict.items():
# Search for the code block
if not isinstance(text, str): continue
if not isinstance(text, str):
spec_parsed_dict[k] = text
continue
_match = re.search(pattern, text, re.DOTALL)
if _match:
# Add the code block to the dictionary
try:
spec_parsed_dict[key] = json.loads(_match.group(1).strip())
spec_parsed_dict[k] = json.loads(_match.group(1).strip())
except:
spec_parsed_dict[key] = _match.group(1).strip()
spec_parsed_dict[k] = _match.group(1).strip()
break
return spec_parsed_dict
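To make the per-key fix above concrete, here is a minimal sketch of the intended behavior (the input values are illustrative):
```python
# Illustrative input: one value carries a fenced JSON block, the other is plain text.
parsed = {
    "tool_call": '```json\n{"name": "search", "args": {"q": "faiss"}}\n```',
    "thought": "plain reasoning text",
}
spec = parse_dict_to_dict(parsed)
assert spec["tool_call"] == {"name": "search", "args": {"q": "faiss"}}  # decoded under its own key, not "json"
assert spec["thought"] == "plain reasoning text"  # non-matching values pass through unchanged
```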

View File

@ -43,7 +43,7 @@ class NebulaHandler:
elif self.space_name:
cypher = f'USE {self.space_name};{cypher}'
logger.debug(cypher)
# logger.debug(cypher)
resp = session.execute(cypher)
if format_res:
@ -247,6 +247,24 @@ class NebulaHandler:
res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res)
def get_all_vertices(self,):
'''
get all vertices
@return:
'''
cypher = "MATCH (v) RETURN v;"
res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res)
def get_relative_vertices(self, vertice):
'''
get vertices adjacent to the given vertex
@return:
'''
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertice}' RETURN id(v2) as id;'''
res = self.execute_cypher(cypher, self.space_name)
return self.result_to_dict(res)
def result_to_dict(self, result) -> dict:
"""
build list for each column, and transform to dataframe

View File

@ -6,6 +6,7 @@ import os
import pickle
import uuid
import warnings
from enum import Enum
from pathlib import Path
from typing import (
Any,
@ -22,10 +23,22 @@ import numpy as np
from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
from langchain.docstore.in_memory import InMemoryDocstore
# from langchain.docstore.in_memory import InMemoryDocstore
from .in_memory import InMemoryDocstore
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
from langchain.vectorstores.utils import maximal_marginal_relevance
class DistanceStrategy(str, Enum):
"""Enumerator of the Distance strategies for calculating distances
between vectors."""
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
DOT_PRODUCT = "DOT_PRODUCT"
JACCARD = "JACCARD"
COSINE = "COSINE"
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
@ -219,6 +232,9 @@ class FAISS(VectorStore):
if self._normalize_L2:
faiss.normalize_L2(vector)
scores, indices = self.index.search(vector, k if filter is None else fetch_k)
# scores can exceed 1 after L2 normalization; renormalize such rows
if self._normalize_L2:
scores = np.array([row / np.linalg.norm(row) if np.max(row) > 1 else row for row in scores])
docs = []
for j, i in enumerate(indices[0]):
if i == -1:
@ -565,7 +581,7 @@ class FAISS(VectorStore):
vecstore = cls(
embedding.embed_query,
index,
InMemoryDocstore(),
InMemoryDocstore({}),
{},
normalize_L2=normalize_L2,
distance_strategy=distance_strategy,

View File

@ -10,13 +10,14 @@ from loguru import logger
# from configs.model_config import EMBEDDING_MODEL
from coagent.embeddings.openai_embedding import OpenAIEmbedding
from coagent.embeddings.huggingface_embedding import HFEmbedding
from coagent.llm_models.llm_config import EmbedConfig
def get_embedding(
engine: str,
text_list: list,
model_path: str = "text2vec-base-chinese",
embedding_device: str = "cpu",
embed_config: EmbedConfig = None,
):
'''
get embedding
@ -25,8 +26,12 @@ def get_embedding(
@return:
'''
emb_res = {}
if engine == 'openai':
if embed_config and embed_config.langchain_embeddings:
emb_res = embed_config.langchain_embeddings.embed_documents(text_list)
emb_res = {
text_list[idx]: emb_res[idx] for idx in range(len(text_list))
}
elif engine == 'openai':
oae = OpenAIEmbedding()
emb_res = oae.get_emb(text_list)
elif engine == 'model':
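A hedged sketch of the new `langchain_embeddings` branch; `FakeEmbeddings` is langchain's stub `Embeddings` class and stands in for any user-supplied implementation (whether `EmbedConfig.check_config` accepts such a minimal config is an assumption here):
```python
from langchain.embeddings.fake import FakeEmbeddings
from coagent.llm_models.llm_config import EmbedConfig

# Any langchain Embeddings object short-circuits the engine dispatch above.
embed_config = EmbedConfig(langchain_embeddings=FakeEmbeddings(size=768))
emb_res = get_embedding(engine="model", text_list=["hello", "world"], embed_config=embed_config)
# emb_res maps each input text to its vector: {"hello": [...], "world": [...]}
```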

View File

@ -0,0 +1,49 @@
"""Simple in memory docstore in the form of a dict."""
from typing import Dict, List, Optional, Union
from langchain.docstore.base import AddableMixin, Docstore
from langchain.docstore.document import Document
class InMemoryDocstore(Docstore, AddableMixin):
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None):
"""Initialize with dict."""
self._dict = _dict if _dict is not None else {}
def add(self, texts: Dict[str, Document]) -> None:
"""Add texts to in memory dictionary.
Args:
texts: dictionary of id -> document.
Returns:
None
"""
overlapping = set(texts).intersection(self._dict)
if overlapping:
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
self._dict = {**self._dict, **texts}
def delete(self, ids: List) -> None:
"""Deleting IDs from in memory dictionary."""
overlapping = set(ids).intersection(self._dict)
if not overlapping:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
for _id in ids:
self._dict.pop(_id)
def search(self, search: str) -> Union[str, Document]:
"""Search via direct lookup.
Args:
search: id of a document to search for.
Returns:
Document if found, else error message.
"""
if search not in self._dict:
return f"ID {search} not found."
else:
return self._dict[search]
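A quick usage sketch for the vendored docstore:
```python
from langchain.docstore.document import Document

store = InMemoryDocstore({})
store.add({"doc-1": Document(page_content="hello world")})
print(store.search("doc-1").page_content)  # -> "hello world"
print(store.search("missing"))             # -> "ID missing not found."
store.delete(["doc-1"])                    # raises ValueError if the id is absent
```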

View File

@ -1,6 +1,8 @@
import os
from functools import lru_cache
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.base import Embeddings
# from configs.model_config import embedding_model_dict
from loguru import logger
@ -12,8 +14,11 @@ def load_embeddings(model: str, device: str, embedding_model_dict: dict):
return embeddings
@lru_cache(1)
def load_embeddings_from_path(model_path: str, device: str):
# @lru_cache(1)
def load_embeddings_from_path(model_path: str, device: str, langchain_embeddings: Embeddings = None):
if langchain_embeddings:
return langchain_embeddings
embeddings = HuggingFaceEmbeddings(model_name=model_path,
model_kwargs={'device': device})
return embeddings

View File

@ -1,8 +1,8 @@
from .openai_model import getChatModel, getExtraModel, getChatModelFromConfig
from .openai_model import getExtraModel, getChatModelFromConfig
from .llm_config import LLMConfig, EmbedConfig
__all__ = [
"getChatModel", "getExtraModel", "getChatModelFromConfig",
"getExtraModel", "getChatModelFromConfig",
"LLMConfig", "EmbedConfig"
]

View File

@ -1,6 +1,9 @@
from dataclasses import dataclass
from typing import List, Union
from langchain.embeddings.base import Embeddings
from langchain.llms.base import LLM, BaseLLM
@dataclass
@ -12,7 +15,8 @@ class LLMConfig:
stop: Union[List[str], str] = None,
api_key: str = "",
api_base_url: str = "",
model_device: str = "cpu",
model_device: str = "cpu", # unusewill delete it
llm: LLM = None,
**kwargs
):
@ -21,7 +25,7 @@ class LLMConfig:
self.stop: Union[List[str], str] = stop
self.api_key: str = api_key
self.api_base_url: str = api_base_url
self.model_device: str = model_device
self.llm: LLM = llm
#
self.check_config()
@ -42,6 +46,7 @@ class EmbedConfig:
embed_model_path: str = "",
embed_engine: str = "",
model_device: str = "cpu",
langchain_embeddings: Embeddings = None,
**kwargs
):
self.embed_model: str = embed_model
@ -51,6 +56,8 @@ class EmbedConfig:
self.api_key: str = api_key
self.api_base_url: str = api_base_url
#
self.langchain_embeddings = langchain_embeddings
#
self.check_config()
def check_config(self, ):

View File

@ -1,38 +1,54 @@
import os
from typing import Union, Optional, List
from loguru import logger
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms.base import LLM
from .llm_config import LLMConfig
# from configs.model_config import (llm_model_dict, LLM_MODEL)
def getChatModel(callBack: AsyncIteratorCallbackHandler = None, temperature=0.3, stop=None):
if callBack is None:
class CustomLLMModel:
def __init__(self, llm: LLM):
self.llm: LLM = llm
def __call__(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def _call(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def predict(self, prompt: str,
stop: Optional[List[str]] = None):
return self.llm(prompt, stop)
def batch(self, prompts: List[str],
stop: Optional[List[str]] = None):
return [self.llm(prompt, stop) for prompt in prompts]
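A hedged sketch of the custom-LLM path that `CustomLLMModel` enables; `FakeListLLM` is langchain's canned-response `LLM` and stands in for any subclass (whether `LLMConfig.check_config` tolerates an empty api_key is an assumption):
```python
from langchain.llms.fake import FakeListLLM

llm_config = LLMConfig(model_name="any-local-model", llm=FakeListLLM(responses=["hi there"]))
model = getChatModelFromConfig(llm_config)  # returns a CustomLLMModel wrapping the LLM
print(model("say hi"))  # -> "hi there"
```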
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ) -> Union[ChatOpenAI, LLM]:
# logger.debug(f"llm type is {type(llm_config.llm)}")
if llm_config is None:
model = ChatOpenAI(
streaming=True,
verbose=True,
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
temperature=temperature,
stop=stop
openai_api_key=os.environ.get("api_key"),
openai_api_base=os.environ.get("api_base_url"),
model_name=os.environ.get("LLM_MODEL", "gpt-3.5-turbo"),
temperature=os.environ.get("temperature", 0.5),
stop=os.environ.get("stop", ""),
)
else:
model = ChatOpenAI(
streaming=True,
verbose=True,
callBack=[callBack],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
temperature=temperature,
stop=stop
)
return model
return model
def getChatModelFromConfig(llm_config: LLMConfig, callBack: AsyncIteratorCallbackHandler = None, ):
if llm_config and llm_config.llm and isinstance(llm_config.llm, LLM):
return CustomLLMModel(llm=llm_config.llm)
if callBack is None:
model = ChatOpenAI(
streaming=True,

View File

@ -0,0 +1,5 @@
# from .base_retrieval import *
# __all__ = [
# "IMRertrieval", "BaseDocRetrieval", "BaseCodeRetrieval", "BaseSearchRetrieval"
# ]

View File

@ -0,0 +1,75 @@
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.base_configs.env_config import KB_ROOT_PATH
from coagent.tools import DocRetrieval, CodeRetrieval
class IMRertrieval:
def __init__(self,):
'''
initialize your custom attributes here
'''
pass
def run(self, ):
'''
execution interface; may use the attributes set in __init__
'''
pass
class BaseDocRetrieval(IMRertrieval):
def __init__(self, knowledge_base_name: str, search_top=5, score_threshold=1.0, embed_config: EmbedConfig=EmbedConfig(), kb_root_path: str=KB_ROOT_PATH):
self.knowledge_base_name = knowledge_base_name
self.search_top = search_top
self.score_threshold = score_threshold
self.embed_config = embed_config
self.kb_root_path = kb_root_path
def run(self, query: str, search_top=None, score_threshold=None, ):
docs = DocRetrieval.run(
query=query, knowledge_base_name=self.knowledge_base_name,
search_top=search_top or self.search_top,
score_threshold=score_threshold or self.score_threshold,
embed_config=self.embed_config,
kb_root_path=self.kb_root_path
)
return docs
class BaseCodeRetrieval(IMRertrieval):
def __init__(self, code_base_name, embed_config: EmbedConfig, llm_config: LLMConfig, search_type = 'tag', code_limit = 1, local_graph_path: str=""):
self.code_base_name = code_base_name
self.embed_config = embed_config
self.llm_config = llm_config
self.search_type = search_type
self.code_limit = code_limit
self.use_nh: bool = False
self.local_graph_path: str = local_graph_path
def run(self, query, history_node_list=[], search_type = None, code_limit=None):
code_docs = CodeRetrieval.run(
code_base_name=self.code_base_name,
query=query,
history_node_list=history_node_list,
code_limit=code_limit or self.code_limit,
search_type=search_type or self.search_type,
llm_config=self.llm_config,
embed_config=self.embed_config,
use_nh=self.use_nh,
local_graph_path=self.local_graph_path
)
return code_docs
class BaseSearchRetrieval(IMRertrieval):
def __init__(self, ):
pass
def run(self, ):
pass
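A hedged usage sketch for the wrappers above (knowledge-base name and embedding paths are illustrative):
```python
embed_config = EmbedConfig(
    embed_engine="model", embed_model="text2vec-base-chinese",
    embed_model_path="embedding_models/text2vec-base-chinese",  # hypothetical local path
)
doc_retrieval = BaseDocRetrieval("my_kb", search_top=5, score_threshold=1.0,
                                 embed_config=embed_config)
docs = doc_retrieval.run("how do I configure the sandbox?")
```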

View File

@ -0,0 +1,6 @@
from .json_loader import JSONLoader
from .jsonl_loader import JSONLLoader
__all__ = [
"JSONLoader", "JSONLLoader"
]

View File

@ -0,0 +1,61 @@
import json
from pathlib import Path
from typing import AnyStr, Callable, Dict, List, Optional, Union
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from coagent.utils.common_utils import read_json_file
class JSONLoader(BaseLoader):
def __init__(
self,
file_path: Union[str, Path],
schema_key: str = "all_text",
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
):
self.file_path = Path(file_path).resolve()
self.schema_key = schema_key
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
def load(self, ) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
datas = read_json_file(self.file_path)
self._parse(datas, docs)
return docs
def _parse(self, datas: List, docs: List[Document]) -> None:
for idx, sample in enumerate(datas):
metadata = dict(
source=str(self.file_path),
seq_num=idx,
)
text = sample.get(self.schema_key, "")
docs.append(Document(page_content=text, metadata=metadata))
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.
Args:
text_splitter: TextSplitter instance to use for splitting documents.
Defaults to RecursiveCharacterTextSplitter.
Returns:
List of Documents.
"""
if text_splitter is None:
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
else:
_text_splitter = text_splitter
docs = self.load()
return _text_splitter.split_documents(docs)
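Illustrative usage (`data.json` is a hypothetical file holding a list of records keyed by `all_text`):
```python
loader = JSONLoader("data.json", schema_key="all_text")
docs = loader.load()              # one Document per record
chunks = loader.load_and_split()  # defaults to RecursiveCharacterTextSplitter
```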

View File

@ -0,0 +1,62 @@
import json
from pathlib import Path
from typing import AnyStr, Callable, Dict, List, Optional, Union
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from coagent.utils.common_utils import read_jsonl_file
class JSONLLoader(BaseLoader):
def __init__(
self,
file_path: Union[str, Path],
schema_key: str = "all_text",
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
):
self.file_path = Path(file_path).resolve()
self.schema_key = schema_key
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
def load(self, ) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
datas = read_jsonl_file(self.file_path)
self._parse(datas, docs)
return docs
def _parse(self, datas: List, docs: List[Document]) -> None:
for idx, sample in enumerate(datas):
metadata = dict(
source=str(self.file_path),
seq_num=idx,
)
text = sample.get(self.schema_key, "")
docs.append(Document(page_content=text, metadata=metadata))
def load_and_split(
self, text_splitter: Optional[TextSplitter] = None
) -> List[Document]:
"""Load Documents and split into chunks. Chunks are returned as Documents.
Args:
text_splitter: TextSplitter instance to use for splitting documents.
Defaults to RecursiveCharacterTextSplitter.
Returns:
List of Documents.
"""
if text_splitter is None:
_text_splitter: TextSplitter = RecursiveCharacterTextSplitter()
else:
_text_splitter = text_splitter
docs = self.load()
return _text_splitter.split_documents(docs)

View File

@ -0,0 +1,3 @@
from .langchain_splitter import LCTextSplitter
__all__ = ["LCTextSplitter"]

View File

@ -0,0 +1,77 @@
import os
import importlib
from loguru import logger
from langchain.document_loaders.base import BaseLoader
from langchain.text_splitter import (
SpacyTextSplitter, RecursiveCharacterTextSplitter
)
# from configs.model_config import (
# CHUNK_SIZE,
# OVERLAP_SIZE,
# ZH_TITLE_ENHANCE
# )
from coagent.utils.path_utils import *
class LCTextSplitter:
'''langchain text splitter: performs file-to-text conversion'''
def __init__(
self, filepath: str, text_splitter_name: str = None,
chunk_size: int = 500,
overlap_size: int = 50
):
self.filepath = filepath
self.ext = os.path.splitext(filepath)[-1].lower()
self.text_splitter_name = text_splitter_name
self.chunk_size = chunk_size
self.overlap_size = overlap_size
if self.ext not in SUPPORTED_EXTS:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.document_loader_name = get_LoaderClass(self.ext)
def file2text(self, ):
loader = self._load_document()
text_splitter = self._load_text_splitter()
if self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
# docs = loader.load()
docs = loader.load_and_split(text_splitter)
# logger.debug(f"please check your file can be loaded, docs.lens {len(docs)}")
else:
docs = loader.load_and_split(text_splitter)
return docs
def _load_document(self, ) -> BaseLoader:
DocumentLoader = EXT2LOADER_DICT[self.ext]
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
else:
loader = DocumentLoader(self.filepath)
return loader
def _load_text_splitter(self, ):
try:
if self.text_splitter_name is None:
text_splitter = SpacyTextSplitter(
pipeline="zh_core_web_sm",
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size,
)
self.text_splitter_name = "SpacyTextSplitter"
# elif self.document_loader_name in ["JSONLoader", "JSONLLoader"]:
# text_splitter = None
else:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
text_splitter = TextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size)
except Exception as e:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.overlap_size,
)
return text_splitter
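A hedged usage sketch (assumes `.txt` is among `SUPPORTED_EXTS`; `notes.txt` is a hypothetical input file):
```python
lc_splitter = LCTextSplitter(
    "notes.txt",
    text_splitter_name="RecursiveCharacterTextSplitter",  # any class name from langchain.text_splitter
    chunk_size=500, overlap_size=50,
)
docs = lc_splitter.file2text()
```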

View File

View File

@ -32,8 +32,8 @@ class PyCodeBox(BaseBox):
self.do_check_net = do_check_net
self.use_stop = use_stop
self.jupyter_work_path = jupyter_work_path
asyncio.run(self.astart())
# self.start()
# asyncio.run(self.astart())
self.start()
# logger.info(f"""remote_url: {self.remote_url},
# remote_ip: {self.remote_ip},
@ -199,13 +199,13 @@ class PyCodeBox(BaseBox):
async def _aget_kernelid(self, ) -> None:
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
# response = requests.get(f"{self.kernel_url}?token={self.token}", headers=headers)
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers) as resp:
async with session.get(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as resp:
if len(await resp.json()) > 0:
self.kernel_id = (await resp.json())[0]["id"]
else:
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers) as response:
async with session.post(f"{self.kernel_url}?token={self.token}", headers=headers, timeout=270) as response:
self.kernel_id = (await response.json())["id"]
# if len(response.json()) > 0:
@ -220,41 +220,45 @@ class PyCodeBox(BaseBox):
return False
try:
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270)
response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=10)
return response.status_code == 200
except requests.exceptions.ConnectionError:
return False
except requests.exceptions.ReadTimeout:
return False
async def _acheck_connect(self, ) -> bool:
if self.kernel_url == "":
return False
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
async with session.get(f"{self.kernel_url}?token={self.token}", timeout=10) as resp:
return resp.status == 200
except aiohttp.ClientConnectorError:
pass
return False
except aiohttp.ServerDisconnectedError:
pass
return False
def _check_port(self, ) -> bool:
try:
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=10)
logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
return response.status_code == 200
except requests.exceptions.ConnectionError:
return False
except requests.exceptions.ReadTimeout:
return False
async def _acheck_port(self, ) -> bool:
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=270) as resp:
async with session.get(f"{self.remote_ip}:{self.remote_port}", timeout=10) as resp:
# logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
return resp.status == 200
except aiohttp.ClientConnectorError:
pass
return False
except aiohttp.ServerDisconnectedError:
pass
return False
def _check_connect_success(self, retry_nums: int = 2) -> bool:
if not self.do_check_net: return True
@ -263,7 +267,7 @@ class PyCodeBox(BaseBox):
try:
connect_status = self._check_connect()
if connect_status:
logger.info(f"{self.remote_url} connection success")
# logger.info(f"{self.remote_url} connection success")
return True
except requests.exceptions.ConnectionError:
logger.info(f"{self.remote_url} connection fail")
@ -301,10 +305,12 @@ class PyCodeBox(BaseBox):
else:
# TODO: auto-detect the local interface
port_status = self._check_port()
self.kernel_url = self.remote_url + "/api/kernels"
connect_status = self._check_connect()
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
if os.environ.get("log_verbose", "0") >= "2":
logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
if port_status and not connect_status:
raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")
logger.error("Port is conflict, please check your codebox's port {self.remote_port}")
if not connect_status:
self.jupyter = subprocess.Popen(
@ -321,14 +327,32 @@ class PyCodeBox(BaseBox):
stdout=subprocess.PIPE,
)
record = []
while True and self.jupyter and len(record)<100:
line = self.jupyter.stderr.readline()
try:
content = line.decode("utf-8")
except:
content = line.decode("gbk")
# logger.debug(content)
record.append(content)
if "control-c" in content.lower():
break
self.kernel_url = self.remote_url + "/api/kernels"
self.do_check_net = True
self._check_connect_success()
self._get_kernelid()
# logger.debug(self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}")
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
self.ws = create_connection(self.wc_url, headers=headers)
retry_nums = 3
while retry_nums>=0:
try:
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
break
except Exception as e:
logger.error(f"create ws connection timeout {e}")
retry_nums -= 1
async def astart(self, ):
'''decide whether to run via an external service or by starting a local notebook'''
@ -369,10 +393,16 @@ class PyCodeBox(BaseBox):
cwd=self.jupyter_work_path
)
while True and self.jupyter:
record = []
while True and self.jupyter and len(record)<100:
line = self.jupyter.stderr.readline()
# logger.debug(line.decode("gbk"))
if "Control-C" in line.decode("gbk"):
try:
content = line.decode("utf-8")
except:
content = line.decode("gbk")
# logger.debug(content)
record.append(content)
if "control-c" in content.lower():
break
self.kernel_url = self.remote_url + "/api/kernels"
self.do_check_net = True
@ -380,7 +410,15 @@ class PyCodeBox(BaseBox):
await self._aget_kernelid()
self.wc_url = self.kernel_url.replace("http", "ws") + f"/{self.kernel_id}/channels?token={self.token}"
headers = {"Authorization": f'Token {self.token}', 'token': self.token}
self.ws = create_connection(self.wc_url, headers=headers)
retry_nums = 3
while retry_nums>=0:
try:
self.ws = create_connection(self.wc_url, headers=headers, timeout=5)
break
except Exception as e:
logger.error(f"create ws connection timeout {e}")
retry_nums -= 1
def status(self,) -> CodeBoxStatus:
if not self.kernel_id:

View File

@ -17,7 +17,7 @@ from coagent.orm.commands import *
from coagent.utils.path_utils import *
from coagent.orm.utils import DocumentFile
from coagent.embeddings.utils import load_embeddings, load_embeddings_from_path
from coagent.text_splitter import LCTextSplitter
from coagent.retrieval.text_splitter import LCTextSplitter
from coagent.llm_models.llm_config import EmbedConfig
@ -46,7 +46,7 @@ class KBService(ABC):
def _load_embeddings(self) -> Embeddings:
# return load_embeddings(self.embed_model, embed_device, embedding_model_dict)
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device)
return load_embeddings_from_path(self.embed_config.embed_model_path, self.embed_config.model_device, self.embed_config.langchain_embeddings)
def create_kb(self):
"""

View File

@ -20,9 +20,6 @@ from coagent.utils.path_utils import *
from coagent.orm.commands import *
from coagent.db_handler.graph_db_handler.nebula_handler import NebulaHandler
from coagent.db_handler.vector_db_handler.chroma_handler import ChromaHandler
# from configs.server_config import NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT
# from configs.server_config import CHROMA_PERSISTENT_PATH
from coagent.base_configs.env_config import (
CB_ROOT_PATH,
NEBULA_HOST, NEBULA_PORT, NEBULA_USER, NEBULA_PASSWORD, NEBULA_STORAGED_PORT,
@ -58,10 +55,11 @@ async def create_cb(zip_file,
model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
) -> BaseResponse:
logger.info('cb_name={}, zip_path={}, do_interpret={}'.format(cb_name, code_path, do_interpret))
embed_config: EmbedConfig = EmbedConfig(**locals())
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals())
# Create selected knowledge base
@ -101,9 +99,10 @@ async def delete_cb(
model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]),
embed_config: EmbedConfig = None,
) -> BaseResponse:
logger.info('cb_name={}'.format(cb_name))
embed_config: EmbedConfig = EmbedConfig(**locals())
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals())
# Create selected knowledge base
if not validate_kb_name(cb_name):
@ -143,18 +142,24 @@ def search_code(cb_name: str = Body(..., examples=["sofaboot"]),
model_name: bool = Body(..., examples=["samples"]),
temperature: bool = Body(..., examples=["samples"]),
model_device: bool = Body(..., examples=["samples"]),
use_nh: bool = True,
local_graph_path: str = '',
embed_config: EmbedConfig = None,
) -> dict:
logger.info('cb_name={}'.format(cb_name))
logger.info('query={}'.format(query))
logger.info('code_limit={}'.format(code_limit))
logger.info('search_type={}'.format(search_type))
logger.info('history_node_list={}'.format(history_node_list))
embed_config: EmbedConfig = EmbedConfig(**locals())
if os.environ.get("log_verbose", "0") >= "2":
logger.info(f'local_graph_path={local_graph_path}')
logger.info('cb_name={}'.format(cb_name))
logger.info('query={}'.format(query))
logger.info('code_limit={}'.format(code_limit))
logger.info('search_type={}'.format(search_type))
logger.info('history_node_list={}'.format(history_node_list))
embed_config: EmbedConfig = EmbedConfig(**locals()) if embed_config is None else embed_config
llm_config: LLMConfig = LLMConfig(**locals())
try:
# load codebase
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config)
cbh = CodeBaseHandler(codebase_name=cb_name, embed_config=embed_config, llm_config=llm_config,
use_nh=use_nh, local_graph_path=local_graph_path)
# search code
context, related_vertices = cbh.search_code(query, search_type=search_type, limit=code_limit)
@ -179,11 +184,13 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
# load codebase
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
password=NEBULA_PASSWORD, space_name=cb_name)
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
if vertex.endswith(".java"):
cypher = f'''MATCH (v1)--(v2:package) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
else:
cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN id(v2) as id;'''
# cypher = f'''MATCH (v1)--(v2) WHERE id(v1) == '{vertex}' RETURN v2;'''
cypher_res = nh.execute_cypher(cypher=cypher, format_res=True)
related_vertices = cypher_res.get('id', [])
related_vertices = [i.as_string() for i in related_vertices]
@ -200,8 +207,8 @@ def search_related_vertices(cb_name: str = Body(..., examples=["sofaboot"]),
def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
vertex: str = Body(..., examples=['***'])) -> dict:
logger.info('cb_name={}'.format(cb_name))
logger.info('vertex={}'.format(vertex))
# logger.info('cb_name={}'.format(cb_name))
# logger.info('vertex={}'.format(vertex))
try:
nh = NebulaHandler(host=NEBULA_HOST, port=NEBULA_PORT, username=NEBULA_USER,
@ -233,7 +240,7 @@ def search_code_by_vertex(cb_name: str = Body(..., examples=["sofaboot"]),
return res
except Exception as e:
logger.exception(e)
return {}
return {'code': ""}
def cb_exists_api(cb_name: str = Body(..., examples=["sofaboot"])) -> bool:

View File

@ -8,17 +8,6 @@ from loguru import logger
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores.utils import DistanceStrategy
# from configs.model_config import (
# KB_ROOT_PATH,
# CACHED_VS_NUM,
# EMBEDDING_MODEL,
# EMBEDDING_DEVICE,
# SCORE_THRESHOLD,
# FAISS_NORMALIZE_L2
# )
# from configs.model_config import embedding_model_dict
from coagent.base_configs.env_config import (
KB_ROOT_PATH,
@ -52,15 +41,15 @@ def load_vector_store(
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
kb_root_path: str = KB_ROOT_PATH,
):
print(f"loading vector store in '{knowledge_base_name}'.")
# print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name, kb_root_path)
if embeddings is None:
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device)
embeddings = load_embeddings_from_path(embed_config.embed_model_path, embed_config.model_device, embed_config.langchain_embeddings)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
distance_strategy = DistanceStrategy.EUCLIDEAN_DISTANCE
distance_strategy = "EUCLIDEAN_DISTANCE"
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=FAISS_NORMALIZE_L2, distance_strategy=distance_strategy)
else:

View File

@ -9,9 +9,7 @@ from pydantic import BaseModel, Field
from loguru import logger
from coagent.llm_models import LLMConfig, EmbedConfig
from .base_tool import BaseToolModel
from coagent.service.cb_api import search_code
@ -29,7 +27,17 @@ class CodeRetrieval(BaseToolModel):
code: str = Field(..., description="检索代码")
@classmethod
def run(cls, code_base_name, query, code_limit=1, history_node_list=[], search_type="tag", llm_config: LLMConfig=None, embed_config: EmbedConfig=None):
def run(cls,
code_base_name,
query,
code_limit=1,
history_node_list=[],
search_type="tag",
llm_config: LLMConfig=None,
embed_config: EmbedConfig=None,
use_nh: bool=True,
local_graph_path: str=''
):
"""excute your tool!"""
search_type = {
@ -45,7 +53,8 @@ class CodeRetrieval(BaseToolModel):
codes = search_code(code_base_name, query, code_limit, search_type=search_type, history_node_list=history_node_list,
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, use_nh=use_nh,
local_graph_path=local_graph_path, embed_config=embed_config
)
return_codes = []
context = codes['context']

View File

@ -5,6 +5,7 @@
@time: 2023/12/14 上午10:24
@desc:
'''
import os
from pydantic import BaseModel, Field
from loguru import logger
@ -40,10 +41,9 @@ class CodeRetrievalSingle(BaseToolModel):
vertex: str = Field(..., description="代码对应 id")
@classmethod
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, **kargs):
def run(cls, code_base_name, query, embed_config: EmbedConfig, llm_config: LLMConfig, search_type="description", **kargs):
"""excute your tool!"""
search_type = 'description'
code_limit = 1
# default
@ -51,10 +51,11 @@ class CodeRetrievalSingle(BaseToolModel):
history_node_list=[],
embed_engine=embed_config.embed_engine, embed_model=embed_config.embed_model, embed_model_path=embed_config.embed_model_path,
model_device=embed_config.model_device, model_name=llm_config.model_name, temperature=llm_config.temperature,
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key
api_base_url=llm_config.api_base_url, api_key=llm_config.api_key, embed_config=embed_config, use_nh=kargs.get("use_nh", True),
local_graph_path=kargs.get("local_graph_path", "")
)
logger.debug(search_result)
if os.environ.get("log_verbose", "0") >= "3":
logger.debug(search_result)
code = search_result['context']
vertex = search_result['related_vertices'][0]
# logger.debug(f"code: {code}, vertex: {vertex}")
@ -83,7 +84,7 @@ class RelatedVerticesRetrival(BaseToolModel):
def run(cls, code_base_name: str, vertex: str, **kargs):
"""execute your tool!"""
related_vertices = search_related_vertices(cb_name=code_base_name, vertex=vertex)
logger.debug(f"related_vertices: {related_vertices}")
# logger.debug(f"related_vertices: {related_vertices}")
return related_vertices
@ -110,6 +111,6 @@ class Vertex2Code(BaseToolModel):
else:
vertex = vertex.strip(' "')
logger.info(f'vertex={vertex}')
# logger.info(f'vertex={vertex}')
res = search_code_by_vertex(cb_name=code_base_name, vertex=vertex)
return res

View File

@ -2,11 +2,7 @@ from pydantic import BaseModel, Field
from loguru import logger
from coagent.llm_models.llm_config import EmbedConfig
from .base_tool import BaseToolModel
from coagent.service.kb_api import search_docs

View File

@ -9,8 +9,10 @@ import numpy as np
from loguru import logger
from .base_tool import BaseToolModel
from duckduckgo_search import DDGS
try:
from duckduckgo_search import DDGS
except:
logger.warning("can't find duckduckgo_search, if you need it, please `pip install duckduckgo_search`")
class DDGSTool(BaseToolModel):

View File

@ -0,0 +1,89 @@
import json
def class_info_decode(data):
'''parse class-related information'''
params_dict = {}
for i in data:
_params_dict = {}
for ii in i:
for k, v in ii.items():
if k=="origin_query": continue
if k == "Code Path":
_params_dict["code_path"] = v.split("#")[0]
_params_dict["function_name"] = ".".join(v.split("#")[1:])
if k == "Class Description":
_params_dict["ClassDescription"] = v
if k == "Class Base":
_params_dict["ClassBase"] = v
if k=="Init Parameters":
_params_dict["Parameters"] = v
code_path = _params_dict["code_path"]
params_dict.setdefault(code_path, []).append(_params_dict)
return params_dict
def method_info_decode(data):
params_dict = {}
for i in data:
_params_dict = {}
for ii in i:
for k, v in ii.items():
if k=="origin_query": continue
if k == "Code Path":
_params_dict["code_path"] = v.split("#")[0]
_params_dict["function_name"] = ".".join(v.split("#")[1:])
if k == "Return Value Description":
_params_dict["Returns"] = v
if k == "Return Type":
_params_dict["ReturnType"] = v
if k=="Parameters":
_params_dict["Parameters"] = v
code_path = _params_dict["code_path"]
params_dict.setdefault(code_path, []).append(_params_dict)
return params_dict
def encode2md(data, md_format):
md_dict = {}
for code_path, params_list in data.items():
for params in params_list:
params["Parameters_text"] = "\n".join([f"{param['param']}({param['param_type']})-{param['param_description']}"
for param in params["Parameters"]])
# params.delete("Parameters")
text=md_format.format(**params)
md_dict.setdefault(code_path, []).append(text)
return md_dict
method_text_md = '''> {function_name}
| Column Name | Content |
|-----------------|-----------------|
| Parameters | {Parameters_text} |
| Returns | {Returns} |
| Return type | {ReturnType} |
'''
class_text_md = '''> {code_path}
Bases: {ClassBase}
{ClassDescription}
{Parameters_text}
'''
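A minimal sketch wiring the decoders and templates together; the record below mirrors the shape `method_info_decode` produces:
```python
method_data = {
    "com/example/StringUtil.java": [{
        "code_path": "com/example/StringUtil.java",
        "function_name": "StringUtil.remove",
        "Returns": "the input with the target substring removed",
        "ReturnType": "String",
        "Parameters": [{"param": "s", "param_type": "String",
                        "param_description": "the input string"}],
    }]
}
md_dict = encode2md(method_data, method_text_md)
print(md_dict["com/example/StringUtil.java"][0])
```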

View File

@ -7,7 +7,7 @@ from pathlib import Path
from io import BytesIO
from fastapi import Body, File, Form, Body, Query, UploadFile
from tempfile import SpooledTemporaryFile
import json
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
@ -109,4 +109,6 @@ def get_uploadfile(file: Union[str, Path, bytes], filename=None) -> UploadFile:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
return UploadFile(file=temp_file, filename=filename)
return UploadFile(file=temp_file, filename=filename)

View File

@ -1,7 +1,7 @@
import os
from langchain.document_loaders import CSVLoader, PyPDFLoader, UnstructuredFileLoader, TextLoader, PythonLoader
from coagent.document_loaders import JSONLLoader, JSONLoader
from coagent.retrieval.document_loaders import JSONLLoader, JSONLoader
# from configs.model_config import (
# embedding_model_dict,
# KB_ROOT_PATH,

View File

@ -21,17 +21,20 @@ JUPYTER_WORK_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath
# WEB_CRAWL storage path
WEB_CRAWL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "knowledge_base")
# NEBULA_DATA storage path
NELUBA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/neluba_data")
NEBULA_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/nebula_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NELUBA_PATH]:
# CHROMA storage path
CHROMA_PERSISTENT_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data/chroma_data")
for _path in [LOG_PATH, SOURCE_PATH, KB_ROOT_PATH, CB_ROOT_PATH, NLTK_DATA_PATH, JUPYTER_WORK_PATH, WEB_CRAWL_PATH, NEBULA_PATH, CHROMA_PERSISTENT_PATH]:
if not os.path.exists(_path):
os.makedirs(_path, exist_ok=True)
#
path_envt_dict = {
"LOG_PATH": LOG_PATH, "SOURCE_PATH": SOURCE_PATH, "KB_ROOT_PATH": KB_ROOT_PATH,
"NLTK_DATA_PATH":NLTK_DATA_PATH, "JUPYTER_WORK_PATH": JUPYTER_WORK_PATH,
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NELUBA_PATH": NELUBA_PATH
"WEB_CRAWL_PATH": WEB_CRAWL_PATH, "NEBULA_PATH": NEBULA_PATH,
"CHROMA_PERSISTENT_PATH": CHROMA_PERSISTENT_PATH
}
for path_name, _path in path_envt_dict.items():
os.environ[path_name] = _path

View File

@ -33,7 +33,7 @@ except:
pass
# add your openai key
OPENAI_API_BASE = "http://openai.com/v1/chat/completions"
OPENAI_API_BASE = "https://api.openai.com/v1"
os.environ["API_BASE_URL"] = OPENAI_API_BASE
os.environ["OPENAI_API_KEY"] = "sk-xx"
openai.api_key = "sk-xx"

View File

@ -58,9 +58,6 @@ NEBULA_GRAPH_SERVER = {
"docker_port": NEBULA_PORT
}
# chroma conf
CHROMA_PERSISTENT_PATH = '/home/user/chatbot/data/chroma_data'
# sandbox api server
SANDBOX_CONTRAINER_NAME = "devopsgpt_sandbox"
SANDBOX_IMAGE_NAME = "devopsgpt:py39"

View File

@ -15,11 +15,11 @@ from coagent.connector.schema import Message
#
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
# log level: print prompt and llm predict
os.environ["log_verbose"] = "0"
os.environ["log_verbose"] = "2"
phase_name = "baseGroupPhase"
llm_config = LLMConfig(
model_name=LLM_MODEL, model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name=LLM_MODEL, api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
phase_name = "baseTaskPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -0,0 +1,135 @@
# encoding: utf-8
'''
@author: 温进
@file: codeChatPhaseLocal_example.py
@time: 2024/1/31 下午4:32
@desc:
'''
import os, sys, requests
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import requests
from typing import List
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.schema import Message, Memory
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# log level: print prompt and llm predict
os.environ["log_verbose"] = "1"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
# delete codebase
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)
# initialize codebase
codebase_name = 'client_local'
code_path = '/Users/bingxu/Desktop/工作/大模型/chatbot/test_code_repo/client'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
code_path = "/home/user/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)
# chat with codebase
phase_name = "codeChatPhase"
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
# "what does the remove function do" => tag-based search
# "is there an existing function that removes a given substring from a string, and if so how do I use it; write the Java code" => description-based search
# "per my requirement below, develop a Java method: take a string as input, delete the '.java' substring from it, and return the new string" => description-based search
## requires nebula running in the container; a codebase built with use_nh=True can be queried via cypher
# round-1
# query_content = "代码一共有多少类"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
#
# output_message1, _ = phase.step(query)
# print(output_message1)
# round-2
# query_content = "代码库里有哪些函数返回5个就行"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
# output_message2, _ = phase.step(query)
# print(output_message2)
# round-3
query_content = "remove 这个函数是做什么的"
query = Message(
role_name="user", role_type="human",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message3, output_memory3 = phase.step(query)
print(output_memory3.to_str_messages(return_all=True, content_key="parsed_output_list"))
#
# # round-4
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message4, output_memory4 = phase.step(query)
print(output_memory4.to_str_messages(return_all=True, content_key="parsed_output_list"))
# # round-5
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="description",
use_nh=False, local_graph_path=CB_ROOT_PATH
)
output_message5, output_memory5 = phase.step(query)
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))

View File

@ -17,13 +17,14 @@ os.environ["log_verbose"] = "2"
phase_name = "codeChatPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
@ -35,50 +36,56 @@ phase = BasePhase(
# "per my requirement below, develop a Java method: take a string as input, delete the '.java' substring from it, and return the new string" => description-based search
# round-1
query_content = "代码一共有多少类"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
)
output_message1, _ = phase.step(query)
# query_content = "代码一共有多少类"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
#
# output_message1, _ = phase.step(query)
# print(output_message1)
# round-2
query_content = "代码库里有哪些函数返回5个就行"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="cypher"
)
output_message2, _ = phase.step(query)
# query_content = "代码库里有哪些函数返回5个就行"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="cypher"
# )
# output_message2, _ = phase.step(query)
# print(output_message2)
# round-3
#
# # round-3
query_content = "remove 这个函数是做什么的"
query = Message(
role_name="user", role_type="human",
role_name="user", role_type="human",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
)
output_message3, _ = phase.step(query)
print(output_message3)
# round-4
query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
)
output_message4, _ = phase.step(query)
# round-5
query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="description"
)
output_message5, output_memory5 = phase.step(query)
print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))
#
# # round-4
# query_content = "有没有函数已经实现了从字符串删除指定字符串的功能使用的话可以怎么使用写个java代码"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
# )
# output_message4, _ = phase.step(query)
# print(output_message4)
#
# # round-5
# query_content = "有根据我以下的需求用 java 开发一个方法:输入为字符串,将输入中的 .java 字符串给删除掉,然后返回新的字符串"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="description"
# )
# output_message5, output_memory5 = phase.step(query)
# print(output_message5)
#
# print(output_memory5.to_str_messages(return_all=True, content_key="parsed_output_list"))

View File

@ -0,0 +1,507 @@
import os, sys, json
from loguru import logger
src_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)
from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message
from coagent.tools import CodeRetrievalSingle
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
import importlib
# define a new agent class
class CodeGenDocer(BaseAgent):
def start_action_step(self, message: Message) -> Message:
'''do action before agent predict '''
# fetch the code snippet and vertex info for the query
action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
current_vertex = action_json['vertex']
message.customed_kargs["Code Snippet"] = action_json["code"]
message.customed_kargs['Current_Vertex'] = current_vertex
return message
# add agent or prompt_manager class
agent_module = importlib.import_module("coagent.connector.agents")
setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
# log level: print prompt and llm predict
os.environ["log_verbose"] = "1"
phase_name = "code2DocsGroup"
llm_config = LLMConfig(
model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
embed_engine="model", embed_model="text2vec-base-chinese",
embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)
# initialize codebase
# delete codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = False
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)
# load codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)
# re-initialize based on the load step above
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
llm_config=llm_config, embed_config=embed_config)
phase = BasePhase(
phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)
for vertex_type in ["class", "method"]:
vertexes = cbh.search_vertices(vertex_type=vertex_type)
logger.info(f"vertexes={vertexes}")
# round-1
docs = []
for vertex in vertexes:
vertex = vertex.split("-")[0] # the part after "-" holds the method's parameters
query_content = f"{vertex_type}节点 {vertex}生成文档"
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
local_graph_path=CB_ROOT_PATH,
)
output_message, output_memory = phase.step(query, reinit_memory=True)
# print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
docs.append(output_memory.get_spec_parserd_output())
os.makedirs(f"{CB_ROOT_PATH}/docs", exist_ok=True)
with open(f"{CB_ROOT_PATH}/docs/raw_{vertex_type}.json", "w") as f:
json.dump(docs, f)
# convert the generated doc info into markdown text below
from coagent.utils.code2doc_util import *
import json
with open(f"{CB_ROOT_PATH}/docs/raw_method.json", "r") as f:
method_raw_data = json.load(f)
with open(f"{CB_ROOT_PATH}/docs/raw_class.json", "r") as f:
class_raw_data = json.load(f)
method_data = method_info_decode(method_raw_data)
class_data = class_info_decode(class_raw_data)
method_mds = encode2md(method_data, method_text_md)
class_mds = encode2md(class_data, class_text_md)
docs_dict = {}
for k,v in class_mds.items():
method_textmds = method_mds.get(k, [])
for vv in v:
# in theory there is only one
text_md = vv
for method_textmd in method_textmds:
text_md += "\n<br>" + method_textmd
docs_dict.setdefault(k, []).append(text_md)
with open(f"{CB_ROOT_PATH}//docs/{k}.md", "w") as f:
f.write(text_md)
####################################
######## full reproduction steps below ########
####################################
# import os, sys, requests
# from loguru import logger
# src_dir = os.path.join(
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# )
# sys.path.append(src_dir)
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
# from coagent.connector.phase import BasePhase
# from coagent.connector.agents import BaseAgent, SelectorAgent
# from coagent.connector.chains import BaseChain
# from coagent.connector.schema import (
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
# )
# from coagent.connector.memory_manager import BaseMemoryManager
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# import importlib
# from loguru import logger
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# # update new agent configs
# codeGenDocGroup_PROMPT = """#### Agent Profile
# Your goal is to respond according to the Context Data's information with the role that will best facilitate a solution, taking into account all relevant context (Context) provided.
# When you need to select the appropriate role for handling a user's query, carefully read the provided role names, role descriptions and tool list.
# ATTENTION: respond carefully, following the "Response Output Format" exactly.
# #### Input Format
# #### Response Output Format
# **Code Path:** Extract the paths for the class/method/function that need to be addressed from the context
# **Role:** Select the role from agent names
# """
# classGenDoc_PROMPT = """#### Agent Profile
# As an advanced code documentation generator, you are proficient in translating class definitions into comprehensive documentation with a focus on instantiation parameters.
# Your specific task is to parse the given code snippet of a class and extract information regarding its instantiation parameters.
# ATTENTION: respond carefully in the "Response Output Format".
# #### Input Format
# **Code Snippet:** Provide the full class definition, including the constructor and any parameters it may require for instantiation.
# #### Response Output Format
# **Class Base:** Specify the base class or interface from which the current class extends, if any.
# **Class Description:** Offer a brief description of the class's purpose and functionality.
# **Init Parameters:** List each parameter from the constructor. For each parameter, provide:
# - `param`: The parameter name
# - `param_description`: A concise explanation of the parameter's purpose.
# - `param_type`: The data type of the parameter, if explicitly defined.
# ```json
# [
# {
# "param": "parameter_name",
# "param_description": "A brief description of what this parameter is used for.",
# "param_type": "The data type of the parameter"
# },
# ...
# ]
# ```
# If the constructor takes no parameters, return
# ```json
# []
# ```
# """
# funcGenDoc_PROMPT = """#### Agent Profile
# You are a high-level code documentation assistant, skilled at turning function/method code into detailed and well-structured documentation.
# ATTENTION: respond carefully in the "Response Output Format".
# #### Input Format
# **Code Path:** Provide the code path of the function or method you wish to document.
# This name will be used to identify and extract the relevant details from the code snippet provided.
# **Code Snippet:** A segment of code that contains the function or method to be documented.
# #### Response Output Format
# **Class Description:** Offer a brief description of the method(function)'s purpose and functionality.
# **Parameters:** Extract the parameters of the specific function/method from the Code Snippet. For each parameter, provide:
# - `param`: The parameter name
# - `param_description`: A concise explanation of the parameter's purpose.
# - `param_type`: The data type of the parameter, if explicitly defined.
# ```json
# [
# {
# "param": "parameter_name",
# "param_description": "A brief description of what this parameter is used for.",
# "param_type": "The data type of the parameter"
# },
# ...
# ]
# ```
# If the function/method takes no parameters, return
# ```json
# []
# ```
# **Return Value Description:** Describe what the function/method returns upon completion.
# **Return Type:** Indicate the type of data the function/method returns (e.g., string, integer, object, void).
# """
# CODE_GENERATE_GROUP_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# {"field_name": 'agent_infomation', "function_name": 'handle_agent_data', "is_context": False, "omit_if_empty": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
# CODE_GENERATE_DOC_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'Specific Objective', "function_name": 'handle_specific_objective'},
# {"field_name": 'Code Snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
# class CodeGenDocPM(PromptManager):
# def handle_code_snippet(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
# instruction = "A segment of code that contains the function or method to be documented.\n"
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
# def handle_specific_objective(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# specific_objective = previous_agent_message.parsed_output.get("Code Path")
# instruction = "Provide the code path of the function or method you wish to document.\n"
# s = instruction + f"\n{specific_objective}"
# return s
# from coagent.tools import CodeRetrievalSingle
# # define a new agent class
# class CodeGenDocer(BaseAgent):
# def start_action_step(self, message: Message) -> Message:
# '''do action before agent predict '''
# # fetch the code snippet and vertex info for the query
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query,
# llm_config=self.llm_config, embed_config=self.embed_config, local_graph_path=message.local_graph_path, use_nh=message.use_nh,search_type="tag")
# current_vertex = action_json['vertex']
# message.customed_kargs["Code Snippet"] = action_json["code"]
# message.customed_kargs['Current_Vertex'] = current_vertex
# return message
# # add agent or prompt_manager class
# agent_module = importlib.import_module("coagent.connector.agents")
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(agent_module, 'CodeGenDocer', CodeGenDocer)
# setattr(prompt_manager_module, 'CodeGenDocPM', CodeGenDocPM)
# AGETN_CONFIGS.update({
# "classGenDoc": {
# "role": {
# "role_prompt": classGenDoc_PROMPT,
# "role_type": "assistant",
# "role_name": "classGenDoc",
# "role_desc": "",
# "agent_type": "CodeGenDocer"
# },
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeGenDocPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# "funcGenDoc": {
# "role": {
# "role_prompt": funcGenDoc_PROMPT,
# "role_type": "assistant",
# "role_name": "funcGenDoc",
# "role_desc": "",
# "agent_type": "CodeGenDocer"
# },
# "prompt_config": CODE_GENERATE_DOC_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeGenDocPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# "codeGenDocsGrouper": {
# "role": {
# "role_prompt": codeGenDocGroup_PROMPT,
# "role_type": "assistant",
# "role_name": "codeGenDocsGrouper",
# "role_desc": "",
# "agent_type": "SelectorAgent"
# },
# "prompt_config": CODE_GENERATE_GROUP_PROMPT_CONFIGS,
# "group_agents": ["classGenDoc", "funcGenDoc"],
# "chat_turn": 1,
# },
# })
# # update new chain configs
# CHAIN_CONFIGS.update({
# "codeGenDocsGroupChain": {
# "chain_name": "codeGenDocsGroupChain",
# "chain_type": "BaseChain",
# "agents": ["codeGenDocsGrouper"],
# "chat_turn": 1,
# "do_checker": False,
# "chain_prompt": ""
# }
# })
# # update phase configs
# PHASE_CONFIGS.update({
# "codeGenDocsGroup": {
# "phase_name": "codeGenDocsGroup",
# "phase_type": "BasePhase",
# "chains": ["codeGenDocsGroupChain"],
# "do_summary": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# },
# })
# role_configs = load_role_configs(AGETN_CONFIGS)
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
# phase_configs = load_phase_configs(PHASE_CONFIGS)
# # log level: print prompts and llm predictions
# os.environ["log_verbose"] = "1"
# phase_name = "codeGenDocsGroup"
# llm_config = LLMConfig(
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
# )
# embed_config = EmbedConfig(
# embed_engine="model", embed_model="text2vec-base-chinese",
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
# )
# # initialize codebase
# # delete codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.delete_codebase(codebase_name=codebase_name)
# # load codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# do_interpret = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.import_code(do_interpret=do_interpret)
# phase = BasePhase(
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
# )
# for vertex_type in ["class", "method"]:
# vertexes = cbh.search_vertices(vertex_type=vertex_type)
# logger.info(f"vertexes={vertexes}")
# # round-1
# docs = []
# for vertex in vertexes:
# vertex = vertex.split("-")[0]  # "-" marks the start of a method's parameter list
# query_content = f"为{vertex_type}节点 {vertex}生成文档"
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name="client_local", score_threshold=1.0, top_k=3, cb_search_type="tag", use_nh=use_nh,
# local_graph_path=CB_ROOT_PATH,
# )
# output_message, output_memory = phase.step(query, reinit_memory=True)
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
# docs.append(output_memory.get_spec_parserd_output())
# import json
# os.makedirs("/home/user/code_base/docs", exist_ok=True)
# with open(f"/home/user/code_base/docs/raw_{vertex_type}.json", "w") as f:
# json.dump(docs, f)
# # convert the generated documentation info into markdown text
# from coagent.utils.code2doc_util import *
# import json
# with open(f"/home/user/code_base/docs/raw_method.json", "r") as f:
# method_raw_data = json.load(f)
# with open(f"/home/user/code_base/docs/raw_class.json", "r") as f:
# class_raw_data = json.load(f)
# method_data = method_info_decode(method_raw_data)
# class_data = class_info_decode(class_raw_data)
# method_mds = encode2md(method_data, method_text_md)
# class_mds = encode2md(class_data, class_text_md)
# docs_dict = {}
# for k,v in class_mds.items():
# method_textmds = method_mds.get(k, [])
# for vv in v:
# # in theory there is exactly one entry per class
# text_md = vv
# for method_textmd in method_textmds:
# text_md += "\n<br>" + method_textmd
# docs_dict.setdefault(k, []).append(text_md)
# with open(f"/home/user/code_base/docs/{k}.md", "w") as f:
# f.write(text_md)

View File

@ -0,0 +1,444 @@
import os, sys, json
from loguru import logger

src_dir = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
sys.path.append(src_dir)

from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
from configs.server_config import SANDBOX_SERVER
from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
from coagent.connector.phase import BasePhase
from coagent.connector.agents import BaseAgent
from coagent.connector.schema import Message
from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
import importlib
from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code


# define a new agent class
class CodeRetrieval(BaseAgent):

    def start_action_step(self, message: Message) -> Message:
        '''do action before agent predict'''
        # fetch the code snippet and vertex info for the query
        action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
                                              local_graph_path=message.local_graph_path, use_nh=message.use_nh)
        current_vertex = action_json['vertex']
        message.customed_kargs["Code Snippet"] = action_json["code"]
        message.customed_kargs['Current_Vertex'] = current_vertex

        # fetch the neighboring vertices
        action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])

        # fetch the code of every neighboring vertex
        relative_vertex = []
        retrieval_Codes = []
        for vertex in action_json["vertices"]:
            # code is stored at file level, so skip vertices from the same file
            # logger.debug(f"{current_vertex}, {vertex}")
            current_vertex_name = current_vertex.replace("#", "").replace(".java", "") if current_vertex.endswith(".java") else current_vertex
            if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue

            action_json = Vertex2Code.run(message.code_engine_name, vertex)
            if action_json["code"]:
                retrieval_Codes.append(action_json["code"])
                relative_vertex.append(vertex)

        message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
        message.customed_kargs["Relative_vertex"] = relative_vertex
        return message


# add agent or prompt_manager class
agent_module = importlib.import_module("coagent.connector.agents")
setattr(agent_module, 'CodeRetrieval', CodeRetrieval)

llm_config = LLMConfig(
    model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
    api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
    embed_engine="model", embed_model="text2vec-base-chinese",
    embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
)

## initialize codebase
# delete codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = False
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
                      llm_config=llm_config, embed_config=embed_config)
cbh.delete_codebase(codebase_name=codebase_name)

# load codebase
codebase_name = 'client_local'
code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
use_nh = True
do_interpret = True
cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
                      llm_config=llm_config, embed_config=embed_config)
cbh.import_code(do_interpret=do_interpret)

cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
                      llm_config=llm_config, embed_config=embed_config)
vertexes = cbh.search_vertices(vertex_type="class")

# log level: print prompts and llm predictions
os.environ["log_verbose"] = "0"

phase_name = "code2Tests"
phase = BasePhase(
    phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
    embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
)

# round-1
logger.debug(vertexes)
test_cases = []
for vertex in vertexes:
    query_content = f"为{vertex}生成可执行的测例"  # "generate an executable test case for {vertex}"
    query = Message(
        role_name="human", role_type="user",
        role_content=query_content, input_query=query_content, origin_query=query_content,
        code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
        use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
    )
    output_message, output_memory = phase.step(query, reinit_memory=True)
    # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
    print(output_memory.get_spec_parserd_output())
    values = output_memory.get_spec_parserd_output()
    test_code = {k: v for i in values for k, v in i.items() if k in ["SaveFileName", "Test Code"]}
    test_cases.append(test_code)

    os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
    with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
        f.write(test_code["Test Code"])
    break
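# The generated test is saved but never verified. As a quick smoke check one could
# try compiling it; the sketch below is only an illustration and assumes a local JDK
# (`javac` on PATH) with the test's dependencies already on the CLASSPATH:
import subprocess

test_path = f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}"
result = subprocess.run(["javac", test_path], capture_output=True, text=True)
if result.returncode != 0:
    logger.warning(f"generated test failed to compile:\n{result.stderr}")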
# import os, sys, json
# from loguru import logger
# src_dir = os.path.join(
# os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# )
# sys.path.append(src_dir)
# from configs.model_config import KB_ROOT_PATH, JUPYTER_WORK_PATH, CB_ROOT_PATH
# from configs.server_config import SANDBOX_SERVER
# from coagent.tools import toLangchainTools, TOOL_DICT, TOOL_SETS
# from coagent.llm_models.llm_config import EmbedConfig, LLMConfig
# from coagent.connector.phase import BasePhase
# from coagent.connector.agents import BaseAgent
# from coagent.connector.chains import BaseChain
# from coagent.connector.schema import (
# Message, Memory, load_role_configs, load_phase_configs, load_chain_configs, ActionStatus
# )
# from coagent.connector.memory_manager import BaseMemoryManager
# from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
# from coagent.connector.prompt_manager.prompt_manager import PromptManager
# from coagent.codechat.codebase_handler.codebase_handler import CodeBaseHandler
# import importlib
# from loguru import logger
# # Given a code snippet, plus code snippets of the classes and methods it depends on, decide whether a test case should be generated for the specified snippet.
# # In theory every piece of code deserves a test case, but human effort is limited and cannot cover everything.
# # Trim the scope by weighing the following factors:
# # Functionality: if it implements a concrete feature or logic, test cases are usually needed to verify its correctness.
# # Complexity: fairly complex code, especially with multiple conditionals, loops, or exception handling, is more likely to hide bugs and should be tested; for complex algorithms or logic, tests help ensure correctness and guard against regressions during future refactoring.
# # Criticality: if it sits on a critical path or affects core functionality, it needs tests; write comprehensive test cases for core business logic and key system components to ensure correctness and stability.
# # Dependencies: code with external dependencies may need integration tests, or unit tests that mock those dependencies.
# # User input: code that handles user input, especially uncontrolled external input, should have tests for input validation and handling.
# # Frequent changes: code that is updated or modified often should have matching tests so changes do not break existing functionality.
# # Public or reused code: if the code will be published or reused in other projects, tests improve its credibility and ease of use.
# # update new agent configs
# judgeGenerateTests_PROMPT = """#### Agent Profile
# When determining the necessity of writing test cases for a given code snippet,
# it's essential to evaluate its interactions with dependent classes and methods (retrieved code snippets),
# in addition to considering these critical factors:
# 1. Functionality: If it implements a concrete function or logic, test cases are typically necessary to verify its correctness.
# 2. Complexity: If the code is complex, especially if it contains multiple conditional statements, loops, exception handling, etc.,
# it's more likely to harbor bugs, and thus test cases should be written.
# If the code involves complex algorithms or logic, then writing test cases can help ensure the accuracy of the logic and prevent errors during future refactoring.
# 3. Criticality: If it's part of the critical path or affects core functionalities, then it needs to be tested.
# Comprehensive test cases should be written for core business logic or key components of the system to ensure the correctness and stability of the functionality.
# 4. Dependencies: If the code has external dependencies, integration testing may be necessary, or mocking these dependencies during unit testing might be required.
# 5. User Input: If the code handles user input, especially from unregulated external sources, creating test cases to check input validation and handling is important.
# 6. Frequent Changes: For code that requires regular updates or modifications, having the appropriate test cases ensures that changes do not break existing functionalities.
# #### Input Format
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
# **Retrieval Code Snippets:** These are the associated code segments that the main Code Snippet depends on.
# Examine these snippets to understand how they interact with the main snippet and to determine how they might affect the overall functionality.
# #### Response Output Format
# **Action Status:** Set to 'finished' or 'continued'.
# If set to 'finished', the code snippet does not warrant the generation of a test case.
# If set to 'continued', the code snippet necessitates the creation of a test case.
# **REASON:** Justify the selection of 'finished' or 'continued', contemplating the decision through a step-by-step rationale.
# """
# generateTests_PROMPT = """#### Agent Profile
# As an agent specializing in software quality assurance,
# your mission is to craft comprehensive test cases that bolster the functionality, reliability, and robustness of a specified Code Snippet.
# This task is to be carried out with a keen understanding of the snippet's interactions with its dependent classes and methods—collectively referred to as Retrieval Code Snippets.
# Analyze the details given below to grasp the code's intended purpose, its inherent complexity, and the context within which it operates.
# Your constructed test cases must thoroughly examine the various factors influencing the code's quality and performance.
# ATTENTION: respond carefully, following the "Response Output Format" exactly.
# Each test case should include:
# 1. clear description of the test purpose.
# 2. The input values or conditions for the test.
# 3. The expected outcome or assertion for the test.
# 4. Appropriate tags (e.g., 'functional', 'integration', 'regression') that classify the type of test case.
# 5. The test code should include the proper package declaration and import statements.
# #### Input Format
# **Code Snippet:** the initial Code or objective that the user wanted to achieve
# **Retrieval Code Snippets:** These are the interrelated pieces of code sourced from the codebase, which support or influence the primary Code Snippet.
# #### Response Output Format
# **SaveFileName:** construct a local file name based on Question and Context, such as
# ```java
# package/class.java
# ```
# **Test Code:** generate the test code for the current Code Snippet.
# ```java
# ...
# ```
# """
# CODE_GENERATE_TESTS_PROMPT_CONFIGS = [
# {"field_name": 'agent_profile', "function_name": 'handle_agent_profile', "is_context": False},
# # {"field_name": 'tool_information',"function_name": 'handle_tool_data', "is_context": False},
# {"field_name": 'context_placeholder', "function_name": '', "is_context": True},
# # {"field_name": 'reference_documents', "function_name": 'handle_doc_info'},
# {"field_name": 'session_records', "function_name": 'handle_session_records'},
# {"field_name": 'code_snippet', "function_name": 'handle_code_snippet'},
# {"field_name": 'retrieval_codes', "function_name": 'handle_retrieval_codes', "description": ""},
# {"field_name": 'output_format', "function_name": 'handle_output_format', 'title': 'Response Output Format', "is_context": False},
# {"field_name": 'begin!!!', "function_name": 'handle_response', "is_context": False, "omit_if_empty": False}
# ]
# class CodeRetrievalPM(PromptManager):
# def handle_code_snippet(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# code_snippet = previous_agent_message.customed_kargs.get("Code Snippet", "")
# current_vertex = previous_agent_message.customed_kargs.get("Current_Vertex", "")
# instruction = "the initial Code or objective that the user wanted to achieve"
# return instruction + "\n" + f"name: {current_vertex}\n{code_snippet}"
# def handle_retrieval_codes(self, **kwargs) -> str:
# if 'previous_agent_message' not in kwargs:
# return ""
# previous_agent_message: Message = kwargs['previous_agent_message']
# Retrieval_Codes = previous_agent_message.customed_kargs["Retrieval_Codes"]
# Relative_vertex = previous_agent_message.customed_kargs["Relative_vertex"]
# instruction = "the initial Code or objective that the user wanted to achieve"
# s = instruction + "\n" + "\n".join([f"name: {vertext}\n{code}" for vertext, code in zip(Relative_vertex, Retrieval_Codes)])
# return s
# AGETN_CONFIGS.update({
# "CodeJudger": {
# "role": {
# "role_prompt": judgeGenerateTests_PROMPT,
# "role_type": "assistant",
# "role_name": "CodeJudger",
# "role_desc": "",
# "agent_type": "CodeRetrieval"
# # "agent_type": "BaseAgent"
# },
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeRetrievalPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# "generateTests": {
# "role": {
# "role_prompt": generateTests_PROMPT,
# "role_type": "assistant",
# "role_name": "generateTests",
# "role_desc": "",
# "agent_type": "CodeRetrieval"
# # "agent_type": "BaseAgent"
# },
# "prompt_config": CODE_GENERATE_TESTS_PROMPT_CONFIGS,
# "prompt_manager_type": "CodeRetrievalPM",
# "chat_turn": 1,
# "focus_agents": [],
# "focus_message_keys": [],
# },
# })
# # update new chain configs
# CHAIN_CONFIGS.update({
# "codeRetrievalChain": {
# "chain_name": "codeRetrievalChain",
# "chain_type": "BaseChain",
# "agents": ["CodeJudger", "generateTests"],
# "chat_turn": 1,
# "do_checker": False,
# "chain_prompt": ""
# }
# })
# # update phase configs
# PHASE_CONFIGS.update({
# "codeGenerateTests": {
# "phase_name": "codeGenerateTests",
# "phase_type": "BasePhase",
# "chains": ["codeRetrievalChain"],
# "do_summary": False,
# "do_search": False,
# "do_doc_retrieval": False,
# "do_code_retrieval": False,
# "do_tool_retrieval": False,
# },
# })
# role_configs = load_role_configs(AGETN_CONFIGS)
# chain_configs = load_chain_configs(CHAIN_CONFIGS)
# phase_configs = load_phase_configs(PHASE_CONFIGS)
# from coagent.tools import CodeRetrievalSingle, RelatedVerticesRetrival, Vertex2Code
# # define a new agent class
# class CodeRetrieval(BaseAgent):
# def start_action_step(self, message: Message) -> Message:
# '''do action before agent predict '''
# # fetch the code snippet and vertex info for the query
# action_json = CodeRetrievalSingle.run(message.code_engine_name, message.origin_query, llm_config=self.llm_config, embed_config=self.embed_config, search_type="tag",
# local_graph_path=message.local_graph_path, use_nh=message.use_nh)
# current_vertex = action_json['vertex']
# message.customed_kargs["Code Snippet"] = action_json["code"]
# message.customed_kargs['Current_Vertex'] = current_vertex
# # fetch the neighboring vertices
# action_json = RelatedVerticesRetrival.run(message.code_engine_name, message.customed_kargs['Current_Vertex'])
# # fetch the code of every neighboring vertex
# relative_vertex = []
# retrieval_Codes = []
# for vertex in action_json["vertices"]:
# # code is stored at file level, so skip vertices from the same file
# # logger.debug(f"{current_vertex}, {vertex}")
# current_vertex_name = current_vertex.replace("#", "").replace(".java", "" ) if current_vertex.endswith(".java") else current_vertex
# if current_vertex_name.split("#")[0] == vertex.split("#")[0]: continue
# action_json = Vertex2Code.run(message.code_engine_name, vertex)
# if action_json["code"]:
# retrieval_Codes.append(action_json["code"])
# relative_vertex.append(vertex)
# #
# message.customed_kargs["Retrieval_Codes"] = retrieval_Codes
# message.customed_kargs["Relative_vertex"] = relative_vertex
# return message
# # add agent or prompt_manager class
# agent_module = importlib.import_module("coagent.connector.agents")
# prompt_manager_module = importlib.import_module("coagent.connector.prompt_manager")
# setattr(agent_module, 'CodeRetrieval', CodeRetrieval)
# setattr(prompt_manager_module, 'CodeRetrievalPM', CodeRetrievalPM)
# # log level: print prompts and llm predictions
# os.environ["log_verbose"] = "0"
# phase_name = "codeGenerateTests"
# llm_config = LLMConfig(
# model_name="gpt-4", api_key=os.environ["OPENAI_API_KEY"],
# api_base_url=os.environ["API_BASE_URL"], temperature=0.3
# )
# embed_config = EmbedConfig(
# embed_engine="model", embed_model="text2vec-base-chinese",
# embed_model_path=os.path.join(src_dir, "embedding_models/text2vec-base-chinese")
# )
# ## initialize codebase
# # delete codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = False
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.delete_codebase(codebase_name=codebase_name)
# # load codebase
# codebase_name = 'client_local'
# code_path = "D://chromeDownloads/devopschat-bot/client_v2/client"
# use_nh = True
# do_interpret = True
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# cbh.import_code(do_interpret=do_interpret)
# cbh = CodeBaseHandler(codebase_name, code_path, crawl_type='dir', use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# llm_config=llm_config, embed_config=embed_config)
# vertexes = cbh.search_vertices(vertex_type="class")
# phase = BasePhase(
# phase_name, sandbox_server=SANDBOX_SERVER, jupyter_work_path=JUPYTER_WORK_PATH,
# embed_config=embed_config, llm_config=llm_config, kb_root_path=KB_ROOT_PATH,
# )
# # round-1
# logger.debug(vertexes)
# test_cases = []
# for vertex in vertexes:
# query_content = f"为{vertex}生成可执行的测例 "
# query = Message(
# role_name="human", role_type="user",
# role_content=query_content, input_query=query_content, origin_query=query_content,
# code_engine_name=codebase_name, score_threshold=1.0, top_k=3, cb_search_type="tag",
# use_nh=use_nh, local_graph_path=CB_ROOT_PATH,
# )
# output_message, output_memory = phase.step(query, reinit_memory=True)
# # print(output_memory.to_str_messages(return_all=True, content_key="parsed_output_list"))
# print(output_memory.get_spec_parserd_output())
# values = output_memory.get_spec_parserd_output()
# test_code = {k:v for i in values for k,v in i.items() if k in ["SaveFileName", "Test Code"]}
# test_cases.append(test_code)
# os.makedirs(f"{CB_ROOT_PATH}/tests", exist_ok=True)
# with open(f"{CB_ROOT_PATH}/tests/{test_code['SaveFileName']}", "w") as f:
# f.write(test_code["Test Code"])

View File

@ -17,7 +17,7 @@ os.environ["log_verbose"] = "2"
phase_name = "codeReactPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -18,8 +18,7 @@ from coagent.connector.schema import (
)
from coagent.connector.memory_manager import BaseMemoryManager
from coagent.connector.configs import AGETN_CONFIGS, CHAIN_CONFIGS, PHASE_CONFIGS, BASE_PROMPT_CONFIGS
from coagent.connector.utils import parse_section
from coagent.connector.prompt_manager import PromptManager
from coagent.connector.prompt_manager.prompt_manager import PromptManager
import importlib
from loguru import logger
@ -230,7 +229,7 @@ os.environ["log_verbose"] = "2"
phase_name = "codeRetrievalPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(
@ -246,7 +245,7 @@ query_content = "UtilsTest 这个类中测试了哪些函数,测试的函数代
query = Message(
role_name="human", role_type="user",
role_content=query_content, input_query=query_content, origin_query=query_content,
code_engine_name="client", score_threshold=1.0, top_k=3, cb_search_type="tag"
code_engine_name="client_1", score_threshold=1.0, top_k=3, cb_search_type="tag"
)

View File

@ -24,7 +24,7 @@ os.environ["log_verbose"] = "2"
phase_name = "codeToolReactPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo-0613", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo-0613", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.7
)
embed_config = EmbedConfig(

View File

@ -17,7 +17,7 @@ from coagent.connector.schema import Message, Memory
tools = toLangchainTools([TOOL_DICT[i] for i in TOOL_SETS if i in TOOL_DICT])
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -18,7 +18,7 @@ os.environ["log_verbose"] = "0"
phase_name = "metagpt_code_devlop"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -20,7 +20,7 @@ os.environ["log_verbose"] = "2"
phase_name = "searchChatPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -18,7 +18,7 @@ os.environ["log_verbose"] = "2"
phase_name = "toolReactPhase"
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo",api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)
embed_config = EmbedConfig(

View File

@ -151,9 +151,9 @@ def create_app():
)(delete_cb)
app.post("/code_base/code_base_chat",
tags=["Code Base Management"],
summary="删除 code_base"
)(delete_cb)
tags=["Code Base Management"],
summary="code_base 对话"
)(search_code)
app.get("/code_base/list_code_bases",
tags=["Code Base Management"],

View File

@ -117,7 +117,7 @@ PHASE_CONFIGS.update({
llm_config = LLMConfig(
model_name="gpt-3.5-turbo", model_device="cpu",api_key=os.environ["OPENAI_API_KEY"],
model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"], temperature=0.3
)

View File

@ -98,12 +98,6 @@ def start_docker(client, script_shs, ports, image_name, container_name, mounts=N
network_name ='my_network'
def start_sandbox_service(network_name ='my_network'):
# networks = client.networks.list()
# if any([network_name==i.attrs["Name"] for i in networks]):
# network = client.networks.get(network_name)
# else:
# network = client.networks.create('my_network', driver='bridge')
mount = Mount(
type='bind',
source=os.path.join(src_dir, "jupyter_work"),
@ -114,6 +108,12 @@ def start_sandbox_service(network_name ='my_network'):
# the sandbox starts independently of the service
if SANDBOX_SERVER["do_remote"]:
client = docker.from_env()
networks = client.networks.list()
if any([network_name==i.attrs["Name"] for i in networks]):
network = client.networks.get(network_name)
else:
network = client.networks.create(network_name, driver='bridge')
# start the container
logger.info("start container sandbox service")
JUPYTER_WORK_PATH = "/home/user/chatbot/jupyter_work"
@ -150,7 +150,7 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
client = docker.from_env()
logger.info("start container service")
check_process("api.py", do_stop=True)
check_process("sdfile_api.py", do_stop=True)
check_process("llm_api.py", do_stop=True)
check_process("sdfile_api.py", do_stop=True)
check_process("webui.py", do_stop=True)
mount = Mount(
@ -159,27 +159,28 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
target='/home/user/chatbot/',
read_only=False # set this to True if read-only access is needed
)
mount_database = Mount(
type='bind',
source=os.path.join(src_dir, "knowledge_base"),
target='/home/user/knowledge_base/',
read_only=False # set this to True if read-only access is needed
)
mount_code_database = Mount(
type='bind',
source=os.path.join(src_dir, "code_base"),
target='/home/user/code_base/',
read_only=False # set this to True if read-only access is needed
)
# mount_database = Mount(
# type='bind',
# source=os.path.join(src_dir, "knowledge_base"),
# target='/home/user/knowledge_base/',
# read_only=False # set this to True if read-only access is needed
# )
# mount_code_database = Mount(
# type='bind',
# source=os.path.join(src_dir, "code_base"),
# target='/home/user/code_base/',
# read_only=False # set this to True if read-only access is needed
# )
ports={
f"{API_SERVER['docker_port']}/tcp": f"{API_SERVER['port']}/tcp",
f"{WEBUI_SERVER['docker_port']}/tcp": f"{WEBUI_SERVER['port']}/tcp",
f"{SDFILE_API_SERVER['docker_port']}/tcp": f"{SDFILE_API_SERVER['port']}/tcp",
f"{NEBULA_GRAPH_SERVER['docker_port']}/tcp": f"{NEBULA_GRAPH_SERVER['port']}/tcp"
}
mounts = [mount, mount_database, mount_code_database]
# mounts = [mount, mount_database, mount_code_database]
mounts = [mount]
script_shs = [
"mkdir -p /home/user/logs",
"mkdir -p /home/user/chatbot/logs",
'''
if [ -d "/home/user/chatbot/data/nebula_data/data/meta" ]; then
cp -r /home/user/chatbot/data/nebula_data/data /usr/local/nebula/
@ -197,12 +198,12 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
"pip install jieba",
"pip install duckduckgo-search",
"nohup python chatbot/examples/sdfile_api.py > /home/user/logs/sdfile_api.log 2>&1 &",
"nohup python chatbot/examples/sdfile_api.py > /home/user/chatbot/logs/sdfile_api.log 2>&1 &",
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
nohup python chatbot/examples/api.py > /home/user/logs/api.log 2>&1 &",
nohup python chatbot/examples/api.py > /home/user/chatbot/logs/api.log 2>&1 &",
"nohup python chatbot/examples/llm_api.py > /home/user/llm.log 2>&1 &",
f"export DUCKDUCKGO_PROXY=socks5://host.docker.internal:13659 && export SANDBOX_HOST={sandbox_host} &&\
cd chatbot/examples && nohup streamlit run webui.py > /home/user/logs/start_webui.log 2>&1 &"
cd chatbot/examples && nohup streamlit run webui.py > /home/user/chatbot/logs/start_webui.log 2>&1 &"
]
if check_docker(client, CONTRAINER_NAME, do_stop=True):
container = start_docker(client, script_shs, ports, IMAGE_NAME, CONTRAINER_NAME, mounts, network=network_name)
@ -212,12 +213,9 @@ def start_api_service(sandbox_host=DEFAULT_BIND_HOST):
# stop the previously started docker services
# check_docker(client, CONTRAINER_NAME, do_stop=True, )
# api_sh = "nohup python ../coagent/service/api.py > ../logs/api.log 2>&1 &"
api_sh = "nohup python api.py > ../logs/api.log 2>&1 &"
# sdfile_sh = "nohup python ../coagent/service/sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
sdfile_sh = "nohup python sdfile_api.py > ../logs/sdfile_api.log 2>&1 &"
notebook_sh = f"nohup jupyter-notebook --NotebookApp.token=mytoken --port={SANDBOX_SERVER['port']} --allow-root --ip=0.0.0.0 --notebook-dir={JUPYTER_WORK_PATH} --no-browser --ServerApp.disable_check_xsrf=True > ../logs/sandbox.log 2>&1 &"
# llm_sh = "nohup python ../coagent/service/llm_api.py > ../logs/llm_api.log 2>&1 &"
llm_sh = "nohup python llm_api.py > ../logs/llm_api.log 2>&1 &"
webui_sh = "streamlit run webui.py" if USE_TTY else "streamlit run webui.py"

View File

@ -22,7 +22,7 @@ from coagent.service.service_factory import get_cb_details, get_cb_details_by_cb
from coagent.orm import table_init
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,llm_model_dict
# SENTENCE_SIZE = 100
cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""")
@ -117,6 +117,8 @@ def code_page(api: ApiRequest):
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
embedding_device=EMBEDDING_DEVICE,
llm_model=LLM_MODEL,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)
st.toast(ret.get("msg", " "))
st.session_state["selected_cb_name"] = cb_name
@ -153,6 +155,8 @@ def code_page(api: ApiRequest):
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
embedding_device=EMBEDDING_DEVICE,
llm_model=LLM_MODEL,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)
st.toast(ret.get("msg", "删除成功"))
time.sleep(0.05)

View File

@ -11,7 +11,7 @@ from coagent.chat.search_chat import SEARCH_ENGINES
from coagent.connector import PHASE_LIST, PHASE_CONFIGS
from coagent.service.service_factory import get_cb_details_by_cb_name
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, embedding_model_dict, EMBEDDING_ENGINE, KB_ROOT_PATH, llm_model_dict
chat_box = ChatBox(
assistant_avatar="../sources/imgs/devops-chatbot2.png"
)
@ -174,7 +174,7 @@ def dialogue_page(api: ApiRequest):
is_detailed = st.toggle(webui_configs["dialogue"]["phase_toggle_detailed_name"], False)
tool_using_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doToolUsing"],
PHASE_CONFIGS[choose_phase]["do_using_tool"])
PHASE_CONFIGS[choose_phase].get("do_using_tool", False))
tool_selects = []
if tool_using_on:
with st.expander("工具军火库", True):
@ -183,7 +183,7 @@ def dialogue_page(api: ApiRequest):
TOOL_SETS, ["WeatherInfo"])
search_on = st.toggle(webui_configs["dialogue"]["phase_toggle_doSearch"],
PHASE_CONFIGS[choose_phase]["do_search"])
PHASE_CONFIGS[choose_phase].get("do_search", False))
search_engine, top_k = None, 3
if search_on:
with st.expander(webui_configs["dialogue"]["expander_search_name"], True):
@ -195,7 +195,8 @@ def dialogue_page(api: ApiRequest):
doc_retrieval_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doDocRetrieval"],
PHASE_CONFIGS[choose_phase]["do_doc_retrieval"])
PHASE_CONFIGS[choose_phase].get("do_doc_retrieval", False)
)
selected_kb, top_k, score_threshold = None, 3, 1.0
if doc_retrieval_on:
with st.expander(webui_configs["dialogue"]["kbase_expander_name"], True):
@ -215,7 +216,7 @@ def dialogue_page(api: ApiRequest):
code_retrieval_on = st.toggle(
webui_configs["dialogue"]["phase_toggle_doCodeRetrieval"],
PHASE_CONFIGS[choose_phase]["do_code_retrieval"])
PHASE_CONFIGS[choose_phase].get("do_code_retrieval", False))
selected_cb, top_k = None, 1
cb_search_type = "tag"
if code_retrieval_on:
@ -296,7 +297,8 @@ def dialogue_page(api: ApiRequest):
r = api.chat_chat(
prompt, history, no_remote_api=True,
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE,api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
llm_model=LLM_MODEL)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
@ -362,6 +364,8 @@ def dialogue_page(api: ApiRequest):
"embed_engine": EMBEDDING_ENGINE,
"kb_root_path": KB_ROOT_PATH,
"model_name": LLM_MODEL,
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
}
text = ""
d = {"docs": []}
@ -405,7 +409,10 @@ def dialogue_page(api: ApiRequest):
api.knowledge_base_chat(
prompt, selected_kb, kb_top_k, score_threshold, history,
embed_model=EMBEDDING_MODEL, embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)
):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
@ -415,11 +422,7 @@ def dialogue_page(api: ApiRequest):
# chat_box.update_msg("知识库匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
chat_box.update_msg(text, element_index=0, streaming=False) # update the final string and remove the cursor
chat_box.update_msg(f"{webui_configs['chat']['chatbox_doc_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
# # check whether code exists, and provide editing and execution features
# code_text = api.codebox.decode_code_from_text(text)
# GLOBAL_EXE_CODE_TEXT = code_text
# if code_text and code_exec_on:
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
elif dialogue_mode == webui_configs["dialogue"]["mode"][2]:
logger.info('prompt={}'.format(prompt))
logger.info('history={}'.format(history))
@ -438,7 +441,9 @@ def dialogue_page(api: ApiRequest):
cb_search_type=cb_search_type,
no_remote_api=True, embed_model=EMBEDDING_MODEL,
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL
embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)):
if error_msg := check_error_msg(d):
st.error(error_msg)
@ -448,6 +453,7 @@ def dialogue_page(api: ApiRequest):
chat_box.update_msg(text, element_index=0)
# postprocess
logger.debug(f"d={d}")
text = replace_lt_gt(text)
chat_box.update_msg(text, element_index=0, streaming=False) # update the final string and remove the cursor
logger.debug('text={}'.format(text))
@ -467,7 +473,9 @@ def dialogue_page(api: ApiRequest):
api.search_engine_chat(
prompt, search_engine, se_top_k, history, embed_model=EMBEDDING_MODEL,
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL)
model_device=EMBEDDING_DEVICE, embed_engine=EMBEDDING_ENGINE, llm_model=LLM_MODEL,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
@ -477,56 +485,11 @@ def dialogue_page(api: ApiRequest):
# chat_box.update_msg("搜索匹配结果: \n\n".join(d["docs"]), element_index=1, streaming=False)
chat_box.update_msg(text, element_index=0, streaming=False) # update the final string and remove the cursor
chat_box.update_msg(f"{webui_configs['chat']['chatbox_search_result']}: \n\n".join(d["docs"]), element_index=1, streaming=False, state="complete")
# # check whether code exists, and provide editing and execution features
# code_text = api.codebox.decode_code_from_text(text)
# GLOBAL_EXE_CODE_TEXT = code_text
# if code_text and code_exec_on:
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
# clear the uploaded files
st.session_state["interpreter_file_key"] += 1
st.experimental_rerun()
# if code_interpreter_on:
# with st.expander(webui_configs['sandbox']['expander_code_name'], False):
# code_part = st.text_area(
# webui_configs['sandbox']['textArea_code_name'], code_text, key="code_text")
# cols = st.columns(2)
# if cols[0].button(
# webui_configs['sandbox']['button_modify_code_name'],
# use_container_width=True,
# ):
# code_text = code_part
# GLOBAL_EXE_CODE_TEXT = code_text
# st.toast(webui_configs['sandbox']['text_modify_code'])
# if cols[1].button(
# webui_configs['sandbox']['button_exec_code_name'],
# use_container_width=True
# ):
# if code_text:
# codebox_res = api.codebox_chat("```"+code_text+"```", do_code_exe=True)
# st.toast(webui_configs['sandbox']['text_execing_code'],)
# else:
# st.toast(webui_configs['sandbox']['text_error_exec_code'],)
# # TODO: this message gets recorded into history
# if codebox_res is not None and codebox_res.code_exe_status != 200:
# st.toast(f"{codebox_res.code_exe_response}")
# if codebox_res is not None and codebox_res.code_exe_status == 200:
# st.toast(f"codebox_chat {codebox_res}")
# chat_box.ai_say(Markdown(code_text, in_expander=True, title="code interpreter", unsafe_allow_html=True), )
# if codebox_res.code_exe_type == "image/png":
# base_text = f"```\n{code_text}\n```\n\n"
# img_html = "<img src='data:image/png;base64,{}' class='img-fluid'>".format(
# codebox_res.code_exe_response
# )
# chat_box.update_msg(img_html, streaming=False, state="complete")
# else:
# chat_box.update_msg('```\n'+code_text+'\n```'+"\n\n"+'```\n'+codebox_res.code_exe_response+'\n```',
# streaming=False, state="complete")
now = datetime.now()
with st.sidebar:

View File

@ -14,7 +14,8 @@ from coagent.orm import table_init
from configs.model_config import (
KB_ROOT_PATH, kbs_config, DEFAULT_VS_TYPE, WEB_CRAWL_PATH,
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict
EMBEDDING_DEVICE, EMBEDDING_ENGINE, EMBEDDING_MODEL, embedding_model_dict,
llm_model_dict
)
# SENTENCE_SIZE = 100
@ -136,6 +137,8 @@ def knowledge_page(
embed_engine=EMBEDDING_ENGINE,
embedding_device= EMBEDDING_DEVICE,
embed_model_path=embedding_model_dict[embed_model],
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
@ -160,7 +163,10 @@ def knowledge_page(
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True, "embed_model": EMBEDDING_MODEL,
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
"model_device": EMBEDDING_DEVICE,
"embed_engine": EMBEDDING_ENGINE}
"embed_engine": EMBEDDING_ENGINE,
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],
}
for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
@ -210,7 +216,9 @@ def knowledge_page(
"embed_model": EMBEDDING_MODEL,
"embed_model_path": embedding_model_dict[EMBEDDING_MODEL],
"model_device": EMBEDDING_DEVICE,
"embed_engine": EMBEDDING_ENGINE}]
"embed_engine": EMBEDDING_ENGINE,
"api_key": llm_model_dict[LLM_MODEL]["api_key"],
"api_base_url": llm_model_dict[LLM_MODEL]["api_base_url"],}]
for k in data:
ret = api.upload_kb_doc(**k)
logger.info(ret)
@ -297,7 +305,9 @@ def knowledge_page(
api.update_kb_doc(kb, row["file_name"],
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE
model_device=EMBEDDING_DEVICE,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
)
st.experimental_rerun()
@ -311,7 +321,9 @@ def knowledge_page(
api.delete_kb_doc(kb, row["file_name"],
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE)
model_device=EMBEDDING_DEVICE,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
st.experimental_rerun()
if cols[3].button(
@ -323,7 +335,9 @@ def knowledge_page(
ret = api.delete_kb_doc(kb, row["file_name"], True,
embed_engine=EMBEDDING_ENGINE,embed_model=EMBEDDING_MODEL,
embed_model_path=embedding_model_dict[EMBEDDING_MODEL],
model_device=EMBEDDING_DEVICE)
model_device=EMBEDDING_DEVICE,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],)
st.toast(ret.get("msg", " "))
st.experimental_rerun()
@ -344,6 +358,8 @@ def knowledge_page(
for d in api.recreate_vector_store(
kb, vs_type=default_vs_type, embed_model=embedding_model, embedding_device=EMBEDDING_DEVICE,
embed_model_path=embedding_model_dict["embedding_model"], embed_engine=EMBEDDING_ENGINE,
api_key=llm_model_dict[LLM_MODEL]["api_key"],
api_base_url=llm_model_dict[LLM_MODEL]["api_base_url"],
):
if msg := check_error_msg(d):
st.toast(msg)

View File

@ -299,7 +299,9 @@ class ApiRequest:
stream: bool = True,
no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/chat endpoint
@ -311,8 +313,8 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -339,7 +341,9 @@ class ApiRequest:
stream: bool = True,
no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/knowledge_base_chat endpoint
@ -355,8 +359,8 @@ class ApiRequest:
"history": history,
"stream": stream,
"local_doc_url": no_remote_api,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -386,7 +390,10 @@ class ApiRequest:
stream: bool = True,
no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/search_engine_chat endpoint
@ -400,8 +407,8 @@ class ApiRequest:
"top_k": top_k,
"history": history,
"stream": stream,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -432,7 +439,9 @@ class ApiRequest:
stream: bool = True,
no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", model_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/knowledge_base_chat endpoint
@ -458,8 +467,8 @@ class ApiRequest:
"cb_search_type": cb_search_type,
"stream": stream,
"local_doc_url": no_remote_api,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -510,6 +519,8 @@ class ApiRequest:
embed_model: str="", embed_model_path: str="",
model_device: str="", embed_engine: str="",
temperature: float=0.2, model_name:str ="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/chat endpoint
@ -541,8 +552,8 @@ class ApiRequest:
"isDetailed": isDetailed,
"upload_file": upload_file,
"kb_root_path": kb_root_path,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -588,6 +599,8 @@ class ApiRequest:
embed_model: str="", embed_model_path: str="",
model_device: str="", embed_engine: str="",
temperature: float=0.2, model_name: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /chat/chat endpoint
@ -620,8 +633,8 @@ class ApiRequest:
"isDetailed": isDetailed,
"upload_file": upload_file,
"kb_root_path": kb_root_path,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -694,7 +707,9 @@ class ApiRequest:
no_remote_api: bool = None,
kb_root_path: str =KB_ROOT_PATH,
embed_model: str="", embed_model_path: str="",
embedding_device: str="", embed_engine: str=""
embedding_device: str="", embed_engine: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
maps to the api.py /knowledge_base/create_knowledge_base endpoint
@ -706,8 +721,8 @@ class ApiRequest:
"knowledge_base_name": knowledge_base_name,
"vector_store_type": vector_store_type,
"kb_root_path": kb_root_path,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"model_device": embedding_device,
@ -781,7 +796,9 @@ class ApiRequest:
no_remote_api: bool = None,
kb_root_path: str = KB_ROOT_PATH,
embed_model: str="", embed_model_path: str="",
model_device: str="", embed_engine: str=""
model_device: str="", embed_engine: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Corresponds to the /knowledge_base/upload_docs endpoint in api.py
@ -810,8 +827,8 @@ class ApiRequest:
override,
not_refresh_vs_cache,
kb_root_path=kb_root_path,
api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"],
api_key=api_key,
api_base_url=api_base_url,
embed_model=embed_model,
embed_model_path=embed_model_path,
model_device=model_device,
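
In the local (no_remote_api) branches the same substitution shows up as keyword forwarding (api_key=api_key, api_base_url=api_base_url) into the underlying call, so local execution honors the same per-call overrides as the HTTP path.
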
@ -839,7 +856,9 @@ class ApiRequest:
no_remote_api: bool = None,
kb_root_path: str = KB_ROOT_PATH,
embed_model: str="", embed_model_path: str="",
model_device: str="", embed_engine: str=""
model_device: str="", embed_engine: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Corresponds to the /knowledge_base/delete_doc endpoint in api.py
@ -853,8 +872,8 @@ class ApiRequest:
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
"kb_root_path": kb_root_path,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"model_device": model_device,
@ -878,7 +897,9 @@ class ApiRequest:
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="",
model_device: str="", embed_engine: str=""
model_device: str="", embed_engine: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Corresponds to the /knowledge_base/update_doc endpoint in api.py
@ -889,8 +910,8 @@ class ApiRequest:
if no_remote_api:
response = run_async(update_doc(
knowledge_base_name, file_name, not_refresh_vs_cache, kb_root_path=KB_ROOT_PATH,
api_key=os.environ["OPENAI_API_KEY"],
api_base_url=os.environ["API_BASE_URL"],
api_key=api_key,
api_base_url=api_base_url,
embed_model=embed_model,
embed_model_path=embed_model_path,
model_device=model_device,
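
One apparent oversight: the local update_doc call above still pins kb_root_path=KB_ROOT_PATH, the module-level constant, instead of threading a kb_root_path parameter through as the sibling methods do.
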
@ -915,7 +936,9 @@ class ApiRequest:
no_remote_api: bool = None,
kb_root_path: str =KB_ROOT_PATH,
embed_model: str="", embed_model_path: str="",
embedding_device: str="", embed_engine: str=""
embedding_device: str="", embed_engine: str="",
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Corresponds to the /knowledge_base/recreate_vector_store endpoint in api.py
@ -928,8 +951,8 @@ class ApiRequest:
"allow_empty_kb": allow_empty_kb,
"vs_type": vs_type,
"kb_root_path": kb_root_path,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"model_device": embedding_device,
@ -1041,7 +1064,9 @@ class ApiRequest:
# code base 相关操作
def create_code_base(self, cb_name, zip_file, do_interpret: bool, no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Create a code_base
@ -1067,8 +1092,8 @@ class ApiRequest:
"cb_name": cb_name,
"code_path": raw_code_path,
"do_interpret": do_interpret,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,
@ -1091,7 +1116,9 @@ class ApiRequest:
def delete_code_base(self, cb_name: str, no_remote_api: bool = None,
embed_model: str="", embed_model_path: str="", embedding_device: str="", embed_engine: str="",
llm_model: str ="", temperature: float= 0.2
llm_model: str ="", temperature: float= 0.2,
api_key: str=os.environ["OPENAI_API_KEY"],
api_base_url: str = os.environ["API_BASE_URL"],
):
'''
Delete a code_base
@ -1102,8 +1129,8 @@ class ApiRequest:
no_remote_api = self.no_remote_api
data = {
"cb_name": cb_name,
"api_key": os.environ["OPENAI_API_KEY"],
"api_base_url": os.environ["API_BASE_URL"],
"api_key": api_key,
"api_base_url": api_base_url,
"embed_model": embed_model,
"embed_model_path": embed_model_path,
"embed_engine": embed_engine,