412 lines
17 KiB
Python
412 lines
17 KiB
Python
|
import time, os, docker, requests, json, uuid, subprocess, time, asyncio, aiohttp, re, traceback
|
||
|
import psutil
|
||
|
from typing import List, Optional, Union
|
||
|
from pathlib import Path
|
||
|
from loguru import logger
|
||
|
|
||
|
from websockets.sync.client import connect as ws_connect_sync
|
||
|
from websockets.client import connect as ws_connect
|
||
|
from websocket import create_connection
|
||
|
from websockets.client import WebSocketClientProtocol, ClientConnection
|
||
|
from websockets.exceptions import ConnectionClosedError
|
||
|
|
||
|
from configs.server_config import SANDBOX_SERVER
|
||
|
from .basebox import BaseBox, CodeBoxResponse, CodeBoxStatus, CodeBoxFile
|
||
|
|
||
|
|
||
|
class PyCodeBox(BaseBox):
|
||
|
|
||
|
enter_status: bool = False
|
||
|
|
||
|
def __init__(
    self,
    remote_url: str = "",
    remote_ip: str = SANDBOX_SERVER["host"],
    remote_port: str = SANDBOX_SERVER["port"],
    token: str = "mytoken",
    do_code_exe: bool = False,
    do_remote: bool = False
):
    """Configure the box and connect to a Jupyter kernel straight away.

    All arguments are forwarded unchanged to ``BaseBox.__init__``; the
    defaults for ``remote_ip`` and ``remote_port`` come from
    ``SANDBOX_SERVER``.

    Construction performs network I/O (and may spawn a Jupyter process),
    so any exception raised by ``astart``/``start`` propagates from here.
    """
    super().__init__(remote_url, remote_ip, remote_port, token, do_code_exe, do_remote)
    # NOTE(review): presumably consulted by context-manager entry code that
    # is not visible in this part of the file -- confirm.
    self.enter_status = True

    try:
        asyncio.get_running_loop()
        loop_is_running = True
    except RuntimeError:
        loop_is_running = False

    if loop_is_running:
        # asyncio.run() refuses to start inside an active event loop, so
        # use the synchronous start-up path when constructed from a coroutine.
        self.start()
    else:
        asyncio.run(self.astart())
|
||
|
|
||
|
def decode_code_from_text(self, text: str) -> str:
    """Extract the contents of every triple-backtick fenced block in *text*.

    The surrounding backticks are removed from each block, and a leading
    ``python`` or ``code`` language tag is dropped. Only a tag at the very
    start of a block is removed; occurrences of those words inside the
    code itself (``code_text``, ``.encode()``, ...) are left untouched.

    Args:
        text: Arbitrary text that may contain fenced code blocks.

    Returns:
        The block bodies joined with newlines, or ``""`` when *text*
        contains no fenced block.
    """
    fenced_blocks = re.findall(r'```.*?```', text, re.DOTALL)
    snippets = []
    for block in fenced_blocks:
        snippet = block.strip('`')
        # The tag must be a whole word so that an untagged inline block
        # starting with e.g. ``code_x = 1`` is not truncated.
        snippet = re.sub(r'^(?:python|code)(?=\s|$)', '', snippet)
        snippets.append(snippet)
    return "\n".join(snippets)
|
||
|
|
||
|
def run(
    self, code_text: Optional[str] = None,
    file_path: Optional[os.PathLike] = None,
    retry = 3,
) -> CodeBoxResponse:
    """Execute Python source in the connected Jupyter kernel.

    Exactly one of ``code_text`` and ``file_path`` must be supplied. The
    code is sent as an ``execute_request`` over ``self.ws`` and kernel
    messages are read until one of them settles the request.

    Args:
        code_text: Source code to execute.
        file_path: Path of a UTF-8 text file whose contents are executed.
        retry: Attempts left; one is consumed each time the websocket
            drops and the box reconnects through ``start()``.

    Returns:
        A ``CodeBoxResponse`` with status 502 for invalid arguments, 200
        for text or image output, 500 for an error reported by the kernel
        and 420 for a ``display_data`` message in an unsupported format.

    Raises:
        RuntimeError: If no attempts are left, no websocket is open, or
            ``self.ws`` is an asyncio websocket.
    """
    # ``start``/``astart`` open the socket with websocket-client's
    # ``create_connection``; that library reports a dropped connection
    # with its own exception type, which must be caught for the retry
    # path to trigger.
    from websocket import WebSocketConnectionClosedException

    def reply(exe_type, payload, status) -> CodeBoxResponse:
        # Every response shares the same code_text/do_code_exe fields.
        return CodeBoxResponse(
            code_exe_response=payload,
            code_text=code_text,
            code_exe_type=exe_type,
            code_exe_status=status,
            do_code_exe=self.do_code_exe,
        )

    if not code_text and not file_path:
        return reply("text", "Code or file_path must be specifieds!", 502)

    if code_text and file_path:
        return reply("text", "Can only specify code or the file to read_from!", 502)

    if file_path:
        with open(file_path, "r", encoding="utf-8") as f:
            code_text = f.read()

    # run code in jupyter kernel
    if retry <= 0:
        raise RuntimeError("Could not connect to kernel")

    if not self.ws:
        raise RuntimeError("Jupyter not running. Make sure to start it first")

    # Checked before sending: on an asyncio socket ``send`` would only
    # create a coroutine that is never awaited.
    if isinstance(self.ws, WebSocketClientProtocol):
        raise RuntimeError("Mixing asyncio and sync code is not supported")

    logger.debug(f"code_text: {json.dumps(code_text, ensure_ascii=False)}")
    msg_id = uuid.uuid4().hex
    request = {
        "header": {
            "msg_id": msg_id,
            "msg_type": "execute_request",
        },
        "parent_header": {},
        "metadata": {},
        "content": {
            "code": code_text,
            "silent": True,
            "store_history": True,
            "user_expressions": {},
            "allow_stdin": False,
            "stop_on_error": True,
        },
        "channel": "shell",
        "buffers": [],
    }
    self.ws.send(json.dumps(request))

    output_parts = []
    while True:
        try:
            received_msg = json.loads(self.ws.recv())
        except (ConnectionClosedError, WebSocketConnectionClosedException):
            logger.debug("box start, ConnectionClosedError!!!")
            self.start()
            # code_text already holds the file contents at this point;
            # passing file_path again would be rejected as "both given".
            return self.run(code_text=code_text, retry=retry - 1)

        msg_type = received_msg["header"]["msg_type"]
        content = received_msg["content"]
        # Some kernel messages carry an empty parent header, so look the
        # id up defensively instead of indexing.
        parent_id = (received_msg.get("parent_header") or {}).get("msg_id")

        if msg_type == "display_data":
            # NOTE(review): unlike the other message types this one is not
            # filtered by parent msg_id (kept as in the original) -- confirm
            # whether output of other requests can arrive here.
            data = content["data"]
            if "image/png" in data:
                return reply("image/png", data["image/png"], 200)
            if "text/plain" in data:
                return reply("text", data["text/plain"], 200)
            return reply("error", "Could not parse output", 420)

        if parent_id != msg_id:
            continue

        if msg_type == "stream":
            text = content["text"].strip()
            if "Requirement already satisfied:" in text:
                continue
            output_parts.append(text + "\n")

        elif msg_type == "execute_result":
            output_parts.append(content["data"].get("text/plain", "").strip() + "\n")

        elif msg_type == "status" and content["execution_state"] == "idle":
            result = "".join(output_parts)
            # Keep only the tail of very long output.
            if len(result) > 500:
                result = "[...]\n" + result[-500:]
            return reply("text", result or "Code run successfully (no output)", 200)

        elif msg_type == "error":
            return reply("error", f"{content['ename']}: {content['evalue']}", 500)
|
||
|
|
||
|
def _get_kernelid(self, ) -> None:
    """Store the id of a usable kernel in ``self.kernel_id``.

    The first kernel listed by ``self.kernel_url`` is reused; when the
    list is empty a new kernel is requested with a POST to the same URL.

    Raises:
        Exception: If the server's reply carries no kernel id.
    """
    headers = {"Authorization": f'Token {self.token}', 'token': self.token}
    url = f"{self.kernel_url}?token={self.token}"

    # Same timeout as the connectivity probes, so an unresponsive server
    # cannot hang the caller indefinitely.
    kernels = requests.get(url, headers=headers, timeout=270).json()
    if isinstance(kernels, list) and kernels:
        self.kernel_id = kernels[0]["id"]
    else:
        created = requests.post(url, headers=headers, timeout=270).json()
        # A reply without an id is reported below rather than as a KeyError.
        self.kernel_id = created.get("id")

    if self.kernel_id is None:
        raise Exception("Could not start kernel")
|
||
|
|
||
|
async def _aget_kernelid(self, ) -> None:
    """Asynchronously store the id of a usable kernel in ``self.kernel_id``.

    The first kernel listed by ``self.kernel_url`` is reused; when the
    list is empty a new kernel is requested with a POST to the same URL.

    Raises:
        Exception: If the server's reply carries no kernel id.
    """
    headers = {"Authorization": f'Token {self.token}', 'token': self.token}
    url = f"{self.kernel_url}?token={self.token}"

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as resp:
            kernels = await resp.json()

        if isinstance(kernels, list) and kernels:
            self.kernel_id = kernels[0]["id"]
        else:
            async with session.post(url, headers=headers) as resp:
                created = await resp.json()
            # A reply without an id is reported below rather than as a KeyError.
            self.kernel_id = created.get("id")

    if self.kernel_id is None:
        raise Exception("Could not start kernel")
|
||
|
def _check_connect(self, ) -> bool:
    """Report whether the kernels endpoint answers with HTTP 200.

    Returns:
        ``False`` when ``self.kernel_url`` is empty, when the server cannot
        be reached or does not answer in time, or when it answers with any
        status other than 200; ``True`` otherwise.
    """
    if self.kernel_url == "":
        return False

    try:
        response = requests.get(f"{self.kernel_url}?token={self.token}", timeout=270)
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # A server that accepts the connection but never answers counts as
        # unreachable for this probe, just like a refused connection.
        return False
    return response.status_code == 200
|
||
|
|
||
|
async def _acheck_connect(self, ) -> bool:
    """Asynchronously report whether the kernels endpoint answers HTTP 200.

    Returns:
        ``False`` when ``self.kernel_url`` is empty, when the server cannot
        be reached, disconnects or does not answer in time, or when it
        answers with any status other than 200; ``True`` otherwise.
    """
    if self.kernel_url == "":
        return False

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.kernel_url}?token={self.token}", timeout=270) as resp:
                return resp.status == 200
    except (
        aiohttp.ClientConnectorError,
        aiohttp.ServerDisconnectedError,
        asyncio.TimeoutError,
    ):
        # Explicit False (rather than falling off the end) keeps the
        # declared ``bool`` return type honest.
        return False
|
||
|
|
||
|
def _check_port(self, ) -> bool:
    """Probe ``http://localhost:<remote_port>`` to see if something answers.

    A warning is logged whenever any HTTP response is received.

    Returns:
        ``True`` only when the response status is 200; ``False`` when the
        status differs, the connection fails or the request times out.
    """
    try:
        response = requests.get(f"http://localhost:{self.remote_port}", timeout=270)
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        # No listener (or one that never answers) means no conflict to report.
        return False

    logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
    return response.status_code == 200
|
||
|
|
||
|
async def _acheck_port(self, ) -> bool:
    """Asynchronously probe ``http://localhost:<remote_port>``.

    A warning is logged whenever any HTTP response is received.

    Returns:
        ``True`` only when the response status is 200; ``False`` when the
        status differs, the connection fails or drops, or the request
        times out.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(f"http://localhost:{self.remote_port}", timeout=270) as resp:
                logger.warning(f"Port is conflict, please check your codebox's port {self.remote_port}")
                return resp.status == 200
    except (
        aiohttp.ClientConnectorError,
        aiohttp.ServerDisconnectedError,
        asyncio.TimeoutError,
    ):
        # Explicit False (rather than falling off the end) keeps the
        # declared ``bool`` return type honest.
        return False
|
||
|
|
||
|
def _check_connect_success(self, retry_nums: int = 5) -> bool:
    """Poll ``_check_connect`` until it succeeds or the attempts run out.

    Args:
        retry_nums: Maximum number of probes; consecutive probes are
            five seconds apart.

    Returns:
        ``True`` as soon as one probe succeeds.

    Raises:
        ConnectionError: If every probe fails.
    """
    for attempts_left in range(retry_nums, 0, -1):
        try:
            connected = self._check_connect()
        except requests.exceptions.ConnectionError:
            connected = False

        if connected:
            logger.info(f"{self.remote_url} connection success")
            return True

        logger.info(f"{self.remote_url} connection fail")
        # No point in waiting after the final attempt.
        if attempts_left > 1:
            time.sleep(5)

    # A concrete Exception subclass (instead of bare BaseException) is
    # still caught by ``except BaseException`` handlers but no longer
    # bypasses ordinary ``except Exception`` ones.
    raise ConnectionError(f"can't connect to {self.remote_url}")
|
||
|
|
||
|
async def _acheck_connect_success(self, retry_nums: int = 5) -> bool:
    """Poll ``_acheck_connect`` until it succeeds or the attempts run out.

    Args:
        retry_nums: Maximum number of probes; consecutive probes are
            five seconds apart.

    Returns:
        ``True`` as soon as one probe succeeds.

    Raises:
        ConnectionError: If every probe fails.
    """
    for attempts_left in range(retry_nums, 0, -1):
        try:
            connected = await self._acheck_connect()
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # The probe uses aiohttp, so these (not the ``requests``
            # exceptions) are what a failing attempt can raise.
            connected = False

        if connected:
            logger.info(f"{self.remote_url} connection success")
            return True

        logger.info(f"{self.remote_url} connection fail")
        # Yield to the event loop while waiting; skip the wait after the
        # final attempt.
        if attempts_left > 1:
            await asyncio.sleep(5)

    # A concrete Exception subclass (instead of bare BaseException) is
    # still caught by ``except BaseException`` handlers but no longer
    # bypasses ordinary ``except Exception`` ones.
    raise ConnectionError(f"can't connect to {self.remote_url}")
|
||
|
|
||
|
def start(self, ):
    """Connect to a Jupyter kernel, launching a local notebook if needed.

    With ``do_remote`` set, the server at ``remote_url`` is used as is.
    Otherwise the local port is probed first: if the kernels endpoint
    does not answer, a ``jupyter notebook`` process is spawned on
    ``remote_port``. In both cases the method then waits for the server,
    picks or creates a kernel and opens the channels websocket, storing
    it in ``self.ws``.

    Raises:
        RuntimeError: If something other than a usable Jupyter server
            already answers on ``remote_port``.
    """
    self.jupyter = None
    # Set before any probe so that ``_check_connect`` tests the real
    # endpoint (the async variant ``astart`` does the same).
    self.kernel_url = self.remote_url + "/api/kernels"

    if not self.do_remote:
        # TODO: auto-detect the local endpoint
        port_status = self._check_port()
        connect_status = self._check_connect()
        logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
        if port_status and not connect_status:
            # A concrete Exception subclass is still caught by
            # ``except BaseException`` handlers.
            raise RuntimeError(f"Port is conflict, please check your codebox's port {self.remote_port}")

        if not connect_status:
            # NOTE(review): stdout/stderr are piped but nothing visible in
            # this file reads them -- confirm the pipes cannot fill up.
            self.jupyter = subprocess.Popen(
                [
                    "jupyter", "notebook",
                    f"--NotebookApp.token={self.token}",
                    f"--port={self.remote_port}",
                    "--no-browser",
                    "--ServerApp.disable_check_xsrf=True",
                ],
                stderr=subprocess.PIPE,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
    # else: TODO: automatically check the date and restart the container

    self._check_connect_success()
    self._get_kernelid()

    # Swap only the scheme (http -> ws, https -> wss); an unbounded
    # replace would also rewrite "http" inside the host or path.
    self.wc_url = self.kernel_url.replace("http", "ws", 1) + f"/{self.kernel_id}/channels?token={self.token}"
    logger.debug(self.wc_url)
    headers = {"Authorization": f'Token {self.token}', 'token': self.token}
    # NOTE(review): websocket-client documents this option as ``header``;
    # confirm whether ``headers`` is actually transmitted.
    self.ws = create_connection(self.wc_url, headers=headers)
|
||
|
|
||
|
async def astart(self, ):
    """Asynchronously connect to a Jupyter kernel, launching one if needed.

    With ``do_remote`` set, the server at ``remote_url`` is used as is.
    Otherwise the local port is probed first: if the kernels endpoint
    does not answer, a ``jupyter notebook`` process is spawned on
    ``remote_port``. In both cases the method then waits for the server,
    picks or creates a kernel and opens the channels websocket, storing
    it in ``self.ws``.

    Raises:
        RuntimeError: If something other than a usable Jupyter server
            already answers on ``remote_port``.
    """
    self.jupyter = None
    self.kernel_url = self.remote_url + "/api/kernels"

    if not self.do_remote:
        # TODO: auto-detect the local endpoint
        port_status = await self._acheck_port()
        connect_status = await self._acheck_connect()
        logger.debug(f"port_status: {port_status}, connect_status: {connect_status}")
        if port_status and not connect_status:
            # A concrete Exception subclass is still caught by
            # ``except BaseException`` handlers.
            raise RuntimeError(f"Port is conflict, please check your codebox's port {self.remote_port}")

        if not connect_status:
            # NOTE(review): stdout/stderr are piped but nothing visible in
            # this file reads them -- confirm the pipes cannot fill up.
            self.jupyter = subprocess.Popen(
                [
                    "jupyter", "notebook",
                    f"--NotebookApp.token={self.token}",
                    f"--port={self.remote_port}",
                    "--no-browser",
                    "--ServerApp.disable_check_xsrf=True",
                ],
                stderr=subprocess.PIPE,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
    # else: TODO: automatically check the date and restart the container

    await self._acheck_connect_success()
    await self._aget_kernelid()

    # Swap only the scheme (http -> ws, https -> wss); an unbounded
    # replace would also rewrite "http" inside the host or path.
    self.wc_url = self.kernel_url.replace("http", "ws", 1) + f"/{self.kernel_id}/channels?token={self.token}"
    headers = {"Authorization": f'Token {self.token}', 'token': self.token}
    # NOTE(review): websocket-client documents this option as ``header``;
    # confirm whether ``headers`` is actually transmitted.
    self.ws = create_connection(self.wc_url, headers=headers)
|
||
|
|
||
|
def status(self,) -> CodeBoxStatus:
    """Report whether the kernel server is reachable.

    A kernel id is looked up first when none is stored yet.

    Returns:
        ``CodeBoxStatus`` with status ``"running"`` when a kernel id is
        known and the kernels endpoint answers HTTP 200, otherwise
        ``"stopped"``.
    """
    if not self.kernel_id:
        self._get_kernelid()

    # ``_check_connect`` sends the token like every other request in this
    # class and turns a refused connection into False instead of raising.
    is_running = bool(self.kernel_id) and self._check_connect()
    return CodeBoxStatus(status="running" if is_running else "stopped")
|
||
|
|
||
|
async def astatus(self,) -> CodeBoxStatus:
    """Asynchronously report whether the kernel server is reachable.

    A kernel id is looked up first when none is stored yet.

    Returns:
        ``CodeBoxStatus`` with status ``"running"`` when a kernel id is
        known and the kernels endpoint answers HTTP 200, otherwise
        ``"stopped"``.
    """
    if not self.kernel_id:
        await self._aget_kernelid()

    # The async probe avoids blocking the event loop with ``requests``
    # and sends the token like every other request in this class.
    is_running = bool(self.kernel_id) and await self._acheck_connect()
    return CodeBoxStatus(status="running" if is_running else "stopped")
|
||
|
|
||
|
def restart(self) -> CodeBoxStatus:
    """Return a status object labelled ``"restared"``.

    No kernel or process is touched here; the method only builds the
    status value.
    """
    # NOTE(review): the status text is spelled exactly as callers have
    # always received it -- confirm before correcting the spelling.
    restart_state = CodeBoxStatus(status="restared")
    return restart_state
|
||
|
|
||
|
def stop(self, ) -> CodeBoxStatus:
    """Shut down the locally spawned Jupyter process and close the websocket.

    Both steps are best-effort: failures are logged and never raised, so
    the method is safe to call from a finalizer.

    Returns:
        ``CodeBoxStatus`` with status ``"stopped"``.
    """
    # ``getattr`` because this may run on an instance whose start-up
    # failed before ``jupyter``/``ws`` were assigned.
    jupyter_proc = getattr(self, "jupyter", None)
    if jupyter_proc is not None:
        port_marker = f'port={self.remote_port}'
        try:
            for process in psutil.process_iter(["pid", "name", "cmdline"]):
                # psutil reports None for fields it could not read.
                proc_name = (process.info["name"] or "").lower()
                proc_cmdline = str(process.info["cmdline"]).lower()
                # Match Jupyter processes serving this box's port.
                if port_marker in proc_cmdline and "jupyter" in proc_name:
                    logger.warning(f'port={self.remote_port}, {process.info}')
                    try:
                        process.terminate()
                    except psutil.Error:
                        # One vanished or protected process must not stop
                        # the scan of the remaining ones.
                        logger.error(traceback.format_exc())
            # Also signal the child we spawned ourselves, whatever name
            # the operating system reports for it.
            jupyter_proc.terminate()
        except Exception:
            logger.error(traceback.format_exc())
        self.jupyter = None

    websocket_conn = getattr(self, "ws", None)
    if websocket_conn is not None:
        try:
            websocket_conn.close()
        except Exception:
            logger.error(traceback.format_exc())
        self.ws = None

    return CodeBoxStatus(status="stopped")
|
||
|
|
||
|
def __del__(self):
    """Finalizer: delegate clean-up of this box to ``stop()``."""
    self.stop()
|