codefuse-chatbot/coagent/chat/knowledge_chat.py

from fastapi import Request
import os, asyncio
from urllib.parse import urlencode
from typing import List
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate
# from configs.model_config import (
# llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
# VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from coagent.base_configs.env_config import KB_ROOT_PATH
from coagent.connector.configs.prompts import ORIGIN_TEMPLATE_PROMPT
from coagent.chat.utils import History, wrap_done
from coagent.utils import BaseResponse
from coagent.llm_models.llm_config import LLMConfig, EmbedConfig
from .base_chat import Chat
from coagent.service.kb_api import search_docs, KBServiceFactory
from loguru import logger
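

# Knowledge-base retrieval chat: searches the configured knowledge base for the
# top-k documents, records their sources, and answers the query with the
# retrieved text supplied to the LLM as context.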
class KnowledgeChat(Chat):

    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 5,
            stream: bool = False,
            score_threshold: float = 1.0,
            local_doc_url: bool = False,
            request: Request = None,
            kb_root_path: str = KB_ROOT_PATH,
    ) -> None:
        super().__init__(engine_name, top_k, stream)
        self.score_threshold = score_threshold
        self.local_doc_url = local_doc_url
        self.request = request
        self.kb_root_path = kb_root_path
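
    # Verify that the named knowledge base exists before answering queries.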
    def check_service_status(self) -> BaseResponse:
        kb = KBServiceFactory.get_service_by_name(self.engine_name, self.kb_root_path)
        if kb is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {self.engine_name}")
        return BaseResponse(code=200, msg=f"找到知识库 {self.engine_name}")
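
    # Retrieve relevant documents, collect source references (local file URLs or
    # download links), and build the LLM chain that answers with that context.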
    def _process(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
        '''process'''
        docs = search_docs(
            query, self.engine_name, self.top_k, self.score_threshold, self.kb_root_path,
            api_key=embed_config.api_key, api_base_url=embed_config.api_base_url, embed_model=embed_config.embed_model,
            embed_model_path=embed_config.embed_model_path, embed_engine=embed_config.embed_engine,
            model_device=embed_config.model_device,
        )
        context = "\n".join([doc.page_content for doc in docs])
        source_documents = []
        for inum, doc in enumerate(docs):
            filename = os.path.split(doc.metadata["source"])[-1]
            if self.local_doc_url:
                url = "file://" + doc.metadata["source"]
            else:
                parameters = urlencode({"knowledge_base_name": self.engine_name, "file_name": filename})
                url = f"{self.request.base_url}knowledge_base/download_doc?" + parameters
            text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n"""
            source_documents.append(text)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", ORIGIN_TEMPLATE_PROMPT)]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        result = {"answer": "", "docs": source_documents}
        return chain, context, result
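
    # Synchronous path: run the chain once and return the retrieved sources plus
    # the generated answer (or the error text if the call fails).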
    def create_task(self, query: str, history: List[History], model, llm_config: LLMConfig, embed_config: EmbedConfig):
        '''Build the LLM generation task.'''
        logger.debug(f"query: {query}, history: {history}")
        chain, context, result = self._process(query, history, model, llm_config, embed_config)
        try:
            content = chain({"context": context, "question": query})
        except Exception as e:
            content = {"text": str(e)}
        return result, content
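
    # Asynchronous path: schedule chain.acall as a background task; wrap_done
    # signals the caller's AsyncIteratorCallbackHandler when generation finishes.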
    def create_atask(self, query, history, model, llm_config: LLMConfig, embed_config: EmbedConfig, callback: AsyncIteratorCallbackHandler):
        chain, context, result = self._process(query, history, model, llm_config, embed_config)
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}), callback.done
        ))
        return task, result
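

# A minimal usage sketch (hypothetical knowledge-base name; constructing the model
# and the LLMConfig/EmbedConfig objects depends on the deployment and is omitted):
#
#     chat = KnowledgeChat(engine_name="my_kb", top_k=5, score_threshold=1.0)
#     status = chat.check_service_status()
#     if status.code == 200:
#         result, content = chat.create_task(query, history, model, llm_config, embed_config)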