codefuse-chatbot/dev_opsgpt/chat/code_chat.py

144 lines
5.6 KiB
Python
Raw Normal View History

# encoding: utf-8
'''
@author: 温进
@file: code_chat.py
@time: 2023/10/24 下午4:04
@desc:
'''
from fastapi import Request, Body
import os, asyncio
from urllib.parse import urlencode
from typing import List
from fastapi.responses import StreamingResponse
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate
from configs.model_config import (
llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CODE_PROMPT_TEMPLATE)
from dev_opsgpt.chat.utils import History, wrap_done
from dev_opsgpt.utils import BaseResponse
from .base_chat import Chat
from dev_opsgpt.llm_models import getChatModel
from dev_opsgpt.service.kb_api import search_docs, KBServiceFactory
from dev_opsgpt.service.cb_api import search_code, cb_exists_api
from loguru import logger
import json
class CodeChat(Chat):
def __init__(
self,
code_base_name: str = '',
code_limit: int = 1,
stream: bool = False,
request: Request = None,
) -> None:
super().__init__(engine_name=code_base_name, stream=stream)
self.engine_name = code_base_name
self.code_limit = code_limit
self.request = request
self.history_node_list = []
def check_service_status(self) -> BaseResponse:
cb = cb_exists_api(self.engine_name)
if not cb:
return BaseResponse(code=404, msg=f"未找到代码库 {self.engine_name}")
return BaseResponse(code=200, msg=f"找到代码库 {self.engine_name}")
def _process(self, query: str, history: List[History], model):
'''process'''
codes_res = search_code(query=query, cb_name=self.engine_name, code_limit=self.code_limit,
history_node_list=self.history_node_list)
codes = codes_res['related_code']
nodes = codes_res['related_node']
# update node names
node_names = [node[0] for node in nodes]
self.history_node_list.extend(node_names)
self.history_node_list = list(set(self.history_node_list))
context = "\n".join(codes)
source_nodes = []
for inum, node_info in enumerate(nodes[0:5]):
node_name, node_type, node_score = node_info[0], node_info[1], node_info[2]
source_nodes.append(f'{inum + 1}. 节点名为 {node_name}, 节点类型为 `{node_type}`, 节点得分为 `{node_score}`')
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", CODE_PROMPT_TEMPLATE)]
)
chain = LLMChain(prompt=chat_prompt, llm=model)
result = {"answer": "", "codes": source_nodes}
return chain, context, result
def chat(
self,
query: str = Body(..., description="用户输入", examples=["hello"]),
history: List[History] = Body(
[], description="历史对话",
examples=[[{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}]]
),
engine_name: str = Body(..., description="知识库名称", examples=["samples"]),
code_limit: int = Body(1, examples=['1']),
stream: bool = Body(False, description="流式输出"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = Body(None),
**kargs
):
self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
self.code_limit = code_limit
self.stream = stream if isinstance(stream, bool) else stream.default
self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
self.request = request
return self._chat(query, history, **kargs)
def _chat(self, query: str, history: List[History], **kargs):
history = [History(**h) if isinstance(h, dict) else h for h in history]
service_status = self.check_service_status()
if service_status.code != 200: return service_status
def chat_iterator(query: str, history: List[History]):
model = getChatModel()
result, content = self.create_task(query, history, model, **kargs)
# logger.info('result={}'.format(result))
# logger.info('content={}'.format(content))
if self.stream:
for token in content["text"]:
result["answer"] = token
yield json.dumps(result, ensure_ascii=False)
else:
for token in content["text"]:
result["answer"] += token
yield json.dumps(result, ensure_ascii=False)
return StreamingResponse(chat_iterator(query, history),
media_type="text/event-stream")
def create_task(self, query: str, history: List[History], model):
'''构建 llm 生成任务'''
chain, context, result = self._process(query, history, model)
logger.info('chain={}'.format(chain))
try:
content = chain({"context": context, "question": query})
except Exception as e:
content = {"text": str(e)}
return result, content
def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
chain, context, result = self._process(query, history, model)
task = asyncio.create_task(wrap_done(
chain.acall({"context": context, "question": query}), callback.done
))
return task, result