# codefuse-chatbot/dev_opsgpt/sandbox/pycodebox.py
import time, os, docker, requests, json, uuid, subprocess, time, asyncio, aiohttp, re, traceback
import psutil
from typing import List, Optional, Union
from loguru import logger
from websockets.sync.client import connect as ws_connect_sync
from websockets.client import connect as ws_connect
from websocket import create_connection
from websockets.client import WebSocketClientProtocol, ClientConnection
from websockets.exceptions import ConnectionClosedError
from configs.server_config import SANDBOX_SERVER
from configs.model_config import JUPYTER_WORK_PATH
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus
class PyCodeBox(BaseBox):
    """Code sandbox that executes Python source inside a Jupyter kernel.

    The box talks to a Jupyter server over its REST API (``/api/kernels``) to
    find or create a kernel, then exchanges execute requests and replies over
    a websocket opened with ``websocket.create_connection``.

    When ``do_remote`` is true an already running server at ``remote_url`` is
    used; otherwise a local ``jupyter notebook`` process is spawned on
    ``remote_port`` if nothing is answering there yet.
    """

    enter_status: bool = False

    def __init__(
            self,
            remote_url: str = "",
            remote_ip: str = SANDBOX_SERVER["host"],
            remote_port: str = SANDBOX_SERVER["port"],
            token: str = "mytoken",
            do_code_exe: bool = False,
            do_remote: bool = False,
            do_check_net: bool = True,
    ):
        """Store the connection settings and connect to the kernel.

        Args:
            remote_url: Base URL of the Jupyter server.
            remote_ip: Host part used for the local port probe.
            remote_port: Port of the Jupyter server.
            token: Jupyter authentication token.
            do_code_exe: Forwarded to every ``CodeBoxResponse``.
            do_remote: Use an external server instead of spawning one.
            do_check_net: Whether to wait for the server to become reachable.
        """
        super().__init__(remote_url, remote_ip, remote_port, token, do_code_exe, do_remote)
        self.enter_status = True
        self.do_check_net = do_check_net
        # NOTE: asyncio.run() cannot be called from inside a running event loop.
        asyncio.run(self.astart())

    # ------------------------------------------------------------------ #
    # small private helpers
    # ------------------------------------------------------------------ #
    def _auth_headers(self) -> dict:
        """Return the token headers sent with every REST request."""
        return {"Authorization": f'Token {self.token}', 'token': self.token}

    def _make_response(
            self,
            code_exe_response: str,
            code_text: Optional[str],
            code_exe_type: str,
            code_exe_status: int,
    ) -> CodeBoxResponse:
        """Build a ``CodeBoxResponse`` carrying this box's ``do_code_exe`` flag."""
        return CodeBoxResponse(
            code_exe_response=code_exe_response,
            code_text=code_text,
            code_exe_type=code_exe_type,
            code_exe_status=code_exe_status,
            do_code_exe=self.do_code_exe,
        )

    def _spawn_jupyter(self, notebook_dir: Optional[str] = None) -> subprocess.Popen:
        """Launch ``jupyter notebook`` in the background and return its handle.

        Args:
            notebook_dir: Optional working directory passed as ``--notebook-dir``.
        """
        command = [
            "jupyter", "notebook",
            f"--NotebookApp.token={self.token}",
            f"--port={self.remote_port}",
            "--no-browser",
            "--ServerApp.disable_check_xsrf=True",
        ]
        if notebook_dir:
            command.append(f"--notebook-dir={notebook_dir}")
        # Popen (not run) so the call returns while the server keeps running.
        # NOTE(review): the pipes are never read in this class; a very chatty
        # server could fill them - consider subprocess.DEVNULL.
        return subprocess.Popen(
            command,
            stderr=subprocess.PIPE,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )

    def _open_channel(self) -> None:
        """Open the websocket to the current kernel and store it in ``self.ws``."""
        # Only the scheme must change (http -> ws, https -> wss), hence count=1.
        self.wc_url = (
            self.kernel_url.replace("http", "ws", 1)
            + f"/{self.kernel_id}/channels?token={self.token}"
        )
        headers = self._auth_headers()
        # NOTE(review): websocket-client documents this keyword as `header`;
        # `headers` looks like it is ignored and auth relies on the URL token
        # - confirm before changing.
        self.ws = create_connection(self.wc_url, headers=headers)

    # ------------------------------------------------------------------ #
    # code extraction / execution
    # ------------------------------------------------------------------ #
    def decode_code_from_text(self, text: str) -> str:
        """Extract the contents of all triple-backtick fenced blocks in *text*.

        A leading ``python`` or ``code`` language tag is removed from each
        block; occurrences of those words inside the code itself are left
        untouched. Blocks are joined with a newline. Returns ``""`` when *text*
        contains no fenced block.
        """
        fenced_blocks = re.findall(r'```.*?```', text, re.DOTALL)
        snippets = []
        for block in fenced_blocks:
            body = block.strip('`')
            # Strip the tag only when it is the first token of the block, so
            # identifiers such as `encode` or `code_text` survive.
            body = re.sub(r'\A(?:python|code)(?=\s|\Z)', '', body, count=1)
            snippets.append(body)
        return "\n".join(snippets)

    def run(
            self, code_text: Optional[str] = None,
            file_path: Optional[os.PathLike] = None,
            retry = 3,
    ) -> CodeBoxResponse:
        """Execute code in the kernel and return its first conclusive result.

        Exactly one of *code_text* and *file_path* must be given.

        Args:
            code_text: Source code to execute.
            file_path: Path of a UTF-8 file whose content is executed.
            retry: Remaining reconnect attempts after a closed connection.

        Returns:
            A ``CodeBoxResponse``; status 502 for invalid arguments, 200 for
            text/image output, 420 for unparseable display data and 500 for an
            execution error reported by the kernel.

        Raises:
            RuntimeError: No websocket is open, the retries are exhausted, or
                an asyncio websocket is used with this synchronous method.
        """
        # The socket is created by websocket-client, which signals a dropped
        # connection with its own exception type. Imported locally so the
        # file-level import block does not have to change.
        from websocket import WebSocketConnectionClosedException

        if not code_text and not file_path:
            return self._make_response(
                "Code or file_path must be specifieds!", code_text, "text", 502)

        if code_text and file_path:
            return self._make_response(
                "Can only specify code or the file to read_from!", code_text, "text", 502)

        if file_path:
            with open(file_path, "r", encoding="utf-8") as f:
                code_text = f.read()

        # run code in jupyter kernel
        if retry <= 0:
            raise RuntimeError("Could not connect to kernel")

        if not self.ws:
            raise RuntimeError("Jupyter not running. Make sure to start it first")

        msg_id = uuid.uuid4().hex
        request = {
            "header": {
                "msg_id": msg_id,
                "msg_type": "execute_request",
            },
            "parent_header": {},
            "metadata": {},
            "content": {
                "code": code_text,
                # if True the kernel runs the code but sends no results/output
                "silent": False,
                # if True the executed code is recorded in the session history
                "store_history": True,
                "user_expressions": {},
                # if True the code may ask for user input while executing
                "allow_stdin": False,
                # if True, code after an error is not executed
                "stop_on_error": True,
            },
            "channel": "shell",
            "buffers": [],
        }
        self.ws.send(json.dumps(request))

        result = ""
        while True:
            try:
                if isinstance(self.ws, WebSocketClientProtocol):
                    raise RuntimeError("Mixing asyncio and sync code is not supported")
                received_msg = json.loads(self.ws.recv())
            except (ConnectionClosedError, WebSocketConnectionClosedException):
                self.start()
                # The file (if any) has already been read into code_text;
                # passing file_path again would trip the "only one of" check.
                return self.run(code_text, None, retry - 1)

            msg_type = received_msg["header"]["msg_type"]
            # Some messages carry an empty parent header, so look the id up
            # defensively instead of indexing.
            parent_header = received_msg.get("parent_header") or {}
            if parent_header.get("msg_id") != msg_id:
                # Belongs to another request (possibly a stale one) - ignore.
                continue

            if msg_type == "stream":
                msg = received_msg["content"]["text"].strip()
                if "Requirement already satisfied:" in msg:
                    continue
                result += msg + "\n"

            elif msg_type == "execute_result":
                result += received_msg["content"]["data"]["text/plain"].strip() + "\n"

            elif msg_type == "display_data":
                data = received_msg["content"]["data"]
                if "image/png" in data:
                    return self._make_response(data["image/png"], code_text, "image/png", 200)
                if "text/plain" in data:
                    return self._make_response(data["text/plain"], code_text, "text", 200)
                # Neither supported mime type is present.
                return self._make_response("Could not parse output", code_text, "error", 420)

            elif (
                msg_type == "status"
                and received_msg["content"]["execution_state"] == "idle"
            ):
                # Kernel finished: hand back the (tail of the) collected output.
                if len(result) > 500:
                    result = "[...]\n" + result[-500:]
                return self._make_response(
                    result or "Code run successfully (no output)", code_text, "text", 200)

            elif msg_type == "error":
                error = (
                    f"{received_msg['content']['ename']}: "
                    f"{received_msg['content']['evalue']}"
                )
                return self._make_response(error, code_text, "error", 500)

    # ------------------------------------------------------------------ #
    # kernel discovery
    # ------------------------------------------------------------------ #
    def _get_kernelid(self, ) -> None:
        """Store the id of the first listed kernel, creating one if none exist.

        Raises:
            Exception: The server returned no kernel id.
        """
        headers = self._auth_headers()
        url = f"{self.kernel_url}?token={self.token}"
        kernels = requests.get(url, headers=headers).json()
        if len(kernels) > 0:
            self.kernel_id = kernels[0]["id"]
        else:
            self.kernel_id = requests.post(url, headers=headers).json()["id"]

        if self.kernel_id is None:
            raise Exception("Could not start kernel")

    async def _aget_kernelid(self, ) -> None:
        """Asynchronous counterpart of ``_get_kernelid`` using aiohttp.

        Raises:
            Exception: The server returned no kernel id.
        """
        headers = self._auth_headers()
        url = f"{self.kernel_url}?token={self.token}"
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as resp:
                kernels = await resp.json()
            if len(kernels) > 0:
                self.kernel_id = kernels[0]["id"]
            else:
                async with session.post(url, headers=headers) as resp:
                    self.kernel_id = (await resp.json())["id"]

        if self.kernel_id is None:
            raise Exception("Could not start kernel")

    # ------------------------------------------------------------------ #
    # connectivity probes
    # ------------------------------------------------------------------ #
    def _check_connect(self, ) -> bool:
        """Return True when the kernels endpoint answers with HTTP 200."""
        if self.kernel_url == "":
            return False

        try:
            response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270)
            return response.status_code == 200
        except requests.exceptions.ConnectionError:
            return False

    async def _acheck_connect(self, ) -> bool:
        """Asynchronous counterpart of ``_check_connect``."""
        if self.kernel_url == "":
            return False
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.kernel_url}?token={self.token}",
                    timeout=aiohttp.ClientTimeout(total=270),
                ) as resp:
                    return resp.status == 200
        except (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError):
            return False

    def _check_port(self, ) -> bool:
        """Return True when something answers with HTTP 200 on ``remote_ip:remote_port``."""
        try:
            response = requests.get(f"{self.remote_ip}:{self.remote_port}", timeout=270)
            # Any response at all means the port is already in use.
            logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
            return response.status_code == 200
        except requests.exceptions.ConnectionError:
            return False

    async def _acheck_port(self, ) -> bool:
        """Asynchronous counterpart of ``_check_port``."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.remote_ip}:{self.remote_port}",
                    timeout=aiohttp.ClientTimeout(total=270),
                ) as resp:
                    # Any response at all means the port is already in use.
                    logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
                    return resp.status == 200
        except (aiohttp.ClientConnectorError, aiohttp.ServerDisconnectedError):
            return False

    def _check_connect_success(self, retry_nums: int = 5) -> bool:
        """Poll ``_check_connect`` every 5 seconds, up to *retry_nums* times.

        Returns True immediately when ``do_check_net`` is false.

        Raises:
            BaseException: The server never became reachable.
        """
        if not self.do_check_net:
            return True

        while retry_nums > 0:
            try:
                connect_status = self._check_connect()
                if connect_status:
                    logger.info(f"{self.remote_url} connection success")
                    return True
            except requests.exceptions.ConnectionError:
                logger.info(f"{self.remote_url} connection fail")
            retry_nums -= 1
            time.sleep(5)
        # NOTE(review): BaseException is kept so existing callers see the same
        # exception type; a dedicated Exception subclass would be preferable.
        raise BaseException(f"can't connect to {self.remote_url}")

    async def _acheck_connect_success(self, retry_nums: int = 5) -> bool:
        """Asynchronous counterpart of ``_check_connect_success``.

        Raises:
            BaseException: The server never became reachable.
        """
        if not self.do_check_net:
            return True

        while retry_nums > 0:
            try:
                connect_status = await self._acheck_connect()
                if connect_status:
                    logger.info(f"{self.remote_url} connection success")
                    return True
            except aiohttp.ClientError:
                logger.info(f"{self.remote_url} connection fail")
            retry_nums -= 1
            # Yield to the event loop instead of blocking it while waiting.
            await asyncio.sleep(5)
        raise BaseException(f"can't connect to {self.remote_url}")

    # ------------------------------------------------------------------ #
    # lifecycle
    # ------------------------------------------------------------------ #
    def start(self, ):
        """Connect to an external Jupyter service or start a local notebook.

        Raises:
            BaseException: The local port is occupied by something that is not
                a reachable Jupyter server, or the server never came up.
        """
        # Keep an existing process handle so that a reconnect (see run())
        # does not make stop() forget the server this box spawned.
        self.jupyter = getattr(self, "jupyter", None)
        # Must be set before any probe: _check_connect() reads it.
        self.kernel_url = self.remote_url + "/api/kernels"

        if not self.do_remote:
            # TODO: automatically detect the local interface
            port_status = self._check_port()
            connect_status = self._check_connect()
            logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
            if port_status and not connect_status:
                raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")

            if not connect_status:
                self.jupyter = self._spawn_jupyter(JUPYTER_WORK_PATH)

            # A local server may still be booting, so always wait for it.
            self.do_check_net = True
        # else: TODO: automatically detect the date and restart the container

        self._check_connect_success()
        self._get_kernelid()
        self._open_channel()

    async def astart(self, ):
        """Asynchronous counterpart of ``start`` (used by ``__init__``).

        Raises:
            BaseException: The local port is occupied by something that is not
                a reachable Jupyter server, or the server never came up.
        """
        # Keep an existing process handle so stop() can still terminate it.
        self.jupyter = getattr(self, "jupyter", None)
        self.kernel_url = self.remote_url + "/api/kernels"

        if not self.do_remote:
            # TODO: automatically detect the local interface
            port_status = await self._acheck_port()
            connect_status = await self._acheck_connect()
            logger.info(f"port_status: {port_status}, connect_status: {connect_status}")
            if port_status and not connect_status:
                raise BaseException(f"Port is conflict, please check your codebox's port {self.remote_port}")

            if not connect_status:
                self.jupyter = self._spawn_jupyter()

            # A local server may still be booting, so always wait for it.
            self.do_check_net = True
        # else: TODO: automatically detect the date and restart the container

        await self._acheck_connect_success()
        await self._aget_kernelid()
        self._open_channel()

    def status(self,) -> CodeBoxStatus:
        """Report ``"running"`` if a kernel id is known and the kernels URL returns 200."""
        if not self.kernel_id:
            self._get_kernelid()

        is_running = bool(
            self.kernel_id
            and requests.get(self.kernel_url, timeout=270).status_code == 200
        )
        return CodeBoxStatus(status="running" if is_running else "stopped")

    async def astatus(self,) -> CodeBoxStatus:
        """Asynchronous counterpart of ``status``."""
        if not self.kernel_id:
            await self._aget_kernelid()

        is_running = False
        if self.kernel_id:
            # Run the blocking HTTP call in the default executor so the event
            # loop stays responsive; exception types are those of `requests`.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None, lambda: requests.get(self.kernel_url, timeout=270))
            is_running = response.status_code == 200
        return CodeBoxStatus(status="running" if is_running else "stopped")

    def restart(self, ) -> CodeBoxStatus:
        """Return a fixed status object; no restart is actually performed."""
        # NOTE(review): "restared" looks like a typo of "restarted"; it is a
        # runtime value callers may compare against, so it is left unchanged.
        return CodeBoxStatus(status="restared")

    def stop(self, ) -> CodeBoxStatus:
        """Terminate the locally spawned Jupyter server and close the websocket.

        Best effort: failures are logged, never raised. Nothing is returned.
        """
        # getattr: stop() is also reached through __del__ on instances whose
        # construction failed before these attributes were assigned.
        try:
            if getattr(self, "jupyter", None) is not None:
                port_marker = f'port={self.remote_port}'
                for process in psutil.process_iter(["pid", "name", "cmdline"]):
                    cmdline = str(process.info["cmdline"]).lower()
                    # psutil may report None for a name it cannot read.
                    proc_name = (process.info["name"] or "").lower()
                    # only touch jupyter processes bound to this box's port
                    if port_marker in cmdline and "jupyter" in proc_name:
                        logger.warning(f'port={self.remote_port}, {process.info}')
                        # terminate the process
                        process.terminate()

                self.jupyter = None
        except Exception:
            logger.error(traceback.format_exc())

        ws = getattr(self, "ws", None)
        if ws is not None:
            try:
                ws.close()
            except Exception:
                logger.error(traceback.format_exc())
            self.ws = None

    def __del__(self):
        """Release the kernel connection when the box is garbage collected."""
        self.stop()