codefuse-chatbot/dev_opsgpt/service/model_workers/minimax.py

173 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastchat.conversation import Conversation
from .base import *
from fastchat import conversation as conv
import sys
import json
# from server.utils import get_httpx_client
from typing import List, Dict
from configs import logger, log_verbose
class MiniMaxWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "embo-01"
def __init__(
self,
*,
model_names: List[str] = ["minimax-api"],
controller_addr: str = None,
worker_addr: str = None,
version: str = "abab5.5-chat",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
self.version = version
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
role_maps = {
"user": self.user_role,
"assistant": self.ai_role,
"system": "system",
}
messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages]
return messages
def do_chat(self, params: ApiChatParams) -> Dict:
# 按照官网推荐直接调用abab 5.5模型
# TODO: 支持指定回复要求支持指定用户名称、AI名称
params.load_config(self.model_names[0])
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
pro = "_pro" if params.is_pro else ""
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
messages = self.validate_messages(params.messages)
data = {
"model": params.version,
"stream": True,
"mask_sensitive_info": True,
"messages": messages,
"temperature": params.temperature,
"top_p": params.top_p,
"tokens_to_generate": params.max_tokens or 1024,
# TODO: 以下参数为minimax特有传入空值会出错。
# "prompt": params.system_message or self.conv.system_message,
# "bot_setting": [],
# "role_meta": params.role_meta,
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
response = client.stream("POST",
url.format(pro=pro, group_id=params.group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回
data = {
"error_code": 500,
"text": f"minimax返回错误的结果{e}",
"error": {
"message": f"minimax返回错误的结果{e}",
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
yield data
continue
data = json.loads(e[6:])
if data.get("usage"):
break
if choices := data.get("choices"):
if chunk := choices[0].get("delta", ""):
text += chunk
yield {"error_code": 0, "text": text}
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
params.load_config(self.model_names[0])
url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}"
headers = {
"Authorization": f"Bearer {params.api_key}",
"Content-Type": "application/json",
}
data = {
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
"texts": [],
"type": "query" if params.to_query else "db",
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:data: {data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client:
result = []
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
result += embeddings
elif error := r.get("base_resp"):
data = {
"code": error["status_code"],
"msg": error["status_msg"],
"error": {
"message": error["status_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += batch_size
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation(
name=self.model_names[0],
system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21002",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21002)