codefuse-chatbot/dev_opsgpt/tools/stock_tool.py

189 lines
8.1 KiB
Python
Raw Permalink Normal View History

import json
import os
import re
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import requests
import numpy as np
from loguru import logger
from .base_tool import BaseToolModel
from dev_opsgpt.utils.common_utils import read_json_file
cur_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
stock_infos = read_json_file(os.path.join(cur_dir, "../sources/tool_datas/stock.json"))
stock_dict = {i["mc"]: i["jys"]+i["dm"] for i in stock_infos}
class StockName(BaseToolModel):
"""
Tips
"""
name: str = "StockName"
description: str = "通过股票名称查询股票代码"
class ToolInputArgs(BaseModel):
"""Input for StockName"""
stock_name: int = Field(..., description="股票名称")
class ToolOutputArgs(BaseModel):
"""Output for StockName"""
stock_code: str = Field(..., description="股票代码")
@staticmethod
def run(stock_name: str):
return stock_dict.get(stock_name, "no stock_code")
class StockInfo(BaseToolModel):
"""
用于查询股票市场数据的StockInfo工具
"""
name: str = "StockInfo"
description: str = "根据提供的股票代码、日期范围和数据频率提供股票市场数据。"
class ToolInputArgs(BaseModel):
"""StockInfo的输入参数。"""
code: str = Field(..., description="要查询的股票代码,格式为'marketcode'")
end_date: Optional[str] = Field(default="", description="数据查询的结束日期。留空则为当前日期。如果没有提供结束日期,就留空")
count: int = Field(default=10, description="返回数据点的数量。")
frequency: str = Field(default='1d', description="数据点的频率,例如,'1d'表示每日,'1w'表示每周,'1M'表示每月,'1m'表示每分钟等。")
class ToolOutputArgs(BaseModel):
"""StockInfo的输出参数。"""
data: dict = Field(default=None, description="查询到的股票市场数据。")
@staticmethod
def run(code: str, count: int, frequency: str, end_date: Optional[str]="") -> "ToolOutputArgs":
"""执行股票数据查询工具。"""
# 该方法封装了调用底层股票数据API的逻辑并将结果格式化为pandas DataFrame。
try:
df = get_price(code, end_date=end_date, count=count, frequency=frequency)
# 将DataFrame转换为输出的字典格式
data = df.reset_index().to_dict(orient='list') # 将dataframe转换为字典列表
return StockInfo.ToolOutputArgs(data=data)
except Exception as e:
logger.exception("获取股票数据时发生错误。")
return e
#-*- coding:utf-8 -*- --------------Ashare 股票行情数据双核心版( https://github.com/mpquant/Ashare )
import json,requests,datetime
import pandas as pd #
#腾讯日线
def get_price_day_tx(code, end_date='', count=10, frequency='1d'): #日线获取
unit='week' if frequency in '1w' else 'month' if frequency in '1M' else 'day' #判断日线,周线,月线
if end_date: end_date=end_date.strftime('%Y-%m-%d') if isinstance(end_date,datetime.date) else end_date.split(' ')[0]
end_date='' if end_date==datetime.datetime.now().strftime('%Y-%m-%d') else end_date #如果日期今天就变成空
URL=f'http://web.ifzq.gtimg.cn/appstock/app/fqkline/get?param={code},{unit},,{end_date},{count},qfq'
st= json.loads(requests.get(URL).content); ms='qfq'+unit; stk=st['data'][code]
buf=stk[ms] if ms in stk else stk[unit] #指数返回不是qfqday,是day
df=pd.DataFrame(buf,columns=['time','open','close','high','low','volume'],dtype='float')
df.time=pd.to_datetime(df.time); df.set_index(['time'], inplace=True); df.index.name='' #处理索引
return df
#腾讯分钟线
def get_price_min_tx(code, end_date=None, count=10, frequency='1d'): #分钟线获取
ts=int(frequency[:-1]) if frequency[:-1].isdigit() else 1 #解析K线周期数
if end_date: end_date=end_date.strftime('%Y-%m-%d') if isinstance(end_date,datetime.date) else end_date.split(' ')[0]
URL=f'http://ifzq.gtimg.cn/appstock/app/kline/mkline?param={code},m{ts},,{count}'
st= json.loads(requests.get(URL).content); buf=st[ 'data'][code]['m'+str(ts)]
df=pd.DataFrame(buf,columns=['time','open','close','high','low','volume','n1','n2'])
df=df[['time','open','close','high','low','volume']]
df[['open','close','high','low','volume']]=df[['open','close','high','low','volume']].astype('float')
df.time=pd.to_datetime(df.time); df.set_index(['time'], inplace=True); df.index.name='' #处理索引
df['close'][-1]=float(st['data'][code]['qt'][code][3]) #最新基金数据是3位的
return df
#sina新浪全周期获取函数分钟线 5m,15m,30m,60m 日线1d=240m 周线1w=1200m 1月=7200m
def get_price_sina(code, end_date='', count=10, frequency='60m'): #新浪全周期获取函数
frequency=frequency.replace('1d','240m').replace('1w','1200m').replace('1M','7200m'); mcount=count
ts=int(frequency[:-1]) if frequency[:-1].isdigit() else 1 #解析K线周期数
if (end_date!='') & (frequency in ['240m','1200m','7200m']):
end_date=pd.to_datetime(end_date) if not isinstance(end_date,datetime.date) else end_date #转换成datetime
unit=4 if frequency=='1200m' else 29 if frequency=='7200m' else 1 #4,29多几个数据不影响速度
count=count+(datetime.datetime.now()-end_date).days//unit #结束时间到今天有多少天自然日(肯定 >交易日)
#print(code,end_date,count)
URL=f'http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/CN_MarketData.getKLineData?symbol={code}&scale={ts}&ma=5&datalen={count}'
dstr= json.loads(requests.get(URL).content);
#df=pd.DataFrame(dstr,columns=['day','open','high','low','close','volume'],dtype='float')
df= pd.DataFrame(dstr,columns=['day','open','high','low','close','volume'])
df['open'] = df['open'].astype(float); df['high'] = df['high'].astype(float); #转换数据类型
df['low'] = df['low'].astype(float); df['close'] = df['close'].astype(float); df['volume'] = df['volume'].astype(float)
df.day=pd.to_datetime(df.day);
df.set_index(['day'], inplace=True);
df.index.name='' #处理索引
if (end_date!='') & (frequency in ['240m','1200m','7200m']):
return df[df.index<=end_date][-mcount:] #日线带结束时间先返回
return df
def get_price(code, end_date='',count=10, frequency='1d', fields=[]): #对外暴露只有唯一函数,这样对用户才是最友好的
xcode= code.replace('.XSHG','').replace('.XSHE','') #证券代码编码兼容处理
xcode='sh'+xcode if ('XSHG' in code) else 'sz'+xcode if ('XSHE' in code) else code
if frequency in ['1d','1w','1M']: #1d日线 1w周线 1M月线
try:
return get_price_sina( xcode, end_date=end_date,count=count,frequency=frequency) #主力
except:
return get_price_day_tx(xcode,end_date=end_date,count=count,frequency=frequency) #备用
if frequency in ['1m','5m','15m','30m','60m']: #分钟线 ,1m只有腾讯接口 5分钟5m 60分钟60m
if frequency in '1m':
return get_price_min_tx(xcode,end_date=end_date,count=count,frequency=frequency)
try:
return get_price_sina(xcode,end_date=end_date,count=count,frequency=frequency) #主力
except:
return get_price_min_tx(xcode,end_date=end_date,count=count,frequency=frequency) #备用
if __name__ == "__main__":
tool = StockInfo()
output = tool.run(code='sh600519', end_date='', count=10, frequency='15m')
print(output.json(indent=2))