189 lines
8.1 KiB
Python
189 lines
8.1 KiB
Python
|
||
import json
|
||
import os
|
||
import re
|
||
from pydantic import BaseModel, Field
|
||
from typing import List, Dict, Optional
|
||
import requests
|
||
import numpy as np
|
||
from loguru import logger
|
||
|
||
from .base_tool import BaseToolModel
|
||
from dev_opsgpt.utils.common_utils import read_json_file
|
||
|
||
|
||
cur_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||
stock_infos = read_json_file(os.path.join(cur_dir, "../sources/tool_datas/stock.json"))
|
||
stock_dict = {i["mc"]: i["jys"]+i["dm"] for i in stock_infos}
|
||
|
||
|
||
class StockName(BaseToolModel):
|
||
"""
|
||
Tips
|
||
"""
|
||
name: str = "StockName"
|
||
description: str = "通过股票名称查询股票代码"
|
||
|
||
class ToolInputArgs(BaseModel):
|
||
"""Input for StockName"""
|
||
stock_name: int = Field(..., description="股票名称")
|
||
|
||
class ToolOutputArgs(BaseModel):
|
||
"""Output for StockName"""
|
||
stock_code: str = Field(..., description="股票代码")
|
||
|
||
@staticmethod
|
||
def run(stock_name: str):
|
||
return stock_dict.get(stock_name, "no stock_code")
|
||
|
||
|
||
|
||
class StockInfo(BaseToolModel):
|
||
"""
|
||
用于查询股票市场数据的StockInfo工具。
|
||
"""
|
||
|
||
name: str = "StockInfo"
|
||
description: str = "根据提供的股票代码、日期范围和数据频率提供股票市场数据。"
|
||
|
||
class ToolInputArgs(BaseModel):
|
||
"""StockInfo的输入参数。"""
|
||
code: str = Field(..., description="要查询的股票代码,格式为'marketcode'")
|
||
end_date: Optional[str] = Field(default="", description="数据查询的结束日期。留空则为当前日期。如果没有提供结束日期,就留空")
|
||
count: int = Field(default=10, description="返回数据点的数量。")
|
||
frequency: str = Field(default='1d', description="数据点的频率,例如,'1d'表示每日,'1w'表示每周,'1M'表示每月,'1m'表示每分钟等。")
|
||
|
||
class ToolOutputArgs(BaseModel):
|
||
"""StockInfo的输出参数。"""
|
||
data: dict = Field(default=None, description="查询到的股票市场数据。")
|
||
|
||
@staticmethod
|
||
def run(code: str, count: int, frequency: str, end_date: Optional[str]="") -> "ToolOutputArgs":
|
||
"""执行股票数据查询工具。"""
|
||
# 该方法封装了调用底层股票数据API的逻辑,并将结果格式化为pandas DataFrame。
|
||
try:
|
||
df = get_price(code, end_date=end_date, count=count, frequency=frequency)
|
||
# 将DataFrame转换为输出的字典格式
|
||
data = df.reset_index().to_dict(orient='list') # 将dataframe转换为字典列表
|
||
return StockInfo.ToolOutputArgs(data=data)
|
||
except Exception as e:
|
||
logger.exception("获取股票数据时发生错误。")
|
||
return e
|
||
|
||
#-*- coding:utf-8 -*- --------------Ashare 股票行情数据双核心版( https://github.com/mpquant/Ashare )
|
||
|
||
import json,requests,datetime
|
||
import pandas as pd #
|
||
|
||
#腾讯日线
|
||
|
||
def get_price_day_tx(code, end_date='', count=10, frequency='1d'): #日线获取
|
||
|
||
unit='week' if frequency in '1w' else 'month' if frequency in '1M' else 'day' #判断日线,周线,月线
|
||
|
||
if end_date: end_date=end_date.strftime('%Y-%m-%d') if isinstance(end_date,datetime.date) else end_date.split(' ')[0]
|
||
|
||
end_date='' if end_date==datetime.datetime.now().strftime('%Y-%m-%d') else end_date #如果日期今天就变成空
|
||
|
||
URL=f'http://web.ifzq.gtimg.cn/appstock/app/fqkline/get?param={code},{unit},,{end_date},{count},qfq'
|
||
|
||
st= json.loads(requests.get(URL).content); ms='qfq'+unit; stk=st['data'][code]
|
||
|
||
buf=stk[ms] if ms in stk else stk[unit] #指数返回不是qfqday,是day
|
||
|
||
df=pd.DataFrame(buf,columns=['time','open','close','high','low','volume'],dtype='float')
|
||
|
||
df.time=pd.to_datetime(df.time); df.set_index(['time'], inplace=True); df.index.name='' #处理索引
|
||
|
||
return df
|
||
|
||
|
||
#腾讯分钟线
|
||
|
||
def get_price_min_tx(code, end_date=None, count=10, frequency='1d'): #分钟线获取
|
||
|
||
ts=int(frequency[:-1]) if frequency[:-1].isdigit() else 1 #解析K线周期数
|
||
|
||
if end_date: end_date=end_date.strftime('%Y-%m-%d') if isinstance(end_date,datetime.date) else end_date.split(' ')[0]
|
||
|
||
URL=f'http://ifzq.gtimg.cn/appstock/app/kline/mkline?param={code},m{ts},,{count}'
|
||
|
||
st= json.loads(requests.get(URL).content); buf=st[ 'data'][code]['m'+str(ts)]
|
||
|
||
df=pd.DataFrame(buf,columns=['time','open','close','high','low','volume','n1','n2'])
|
||
|
||
df=df[['time','open','close','high','low','volume']]
|
||
|
||
df[['open','close','high','low','volume']]=df[['open','close','high','low','volume']].astype('float')
|
||
|
||
df.time=pd.to_datetime(df.time); df.set_index(['time'], inplace=True); df.index.name='' #处理索引
|
||
|
||
df['close'][-1]=float(st['data'][code]['qt'][code][3]) #最新基金数据是3位的
|
||
|
||
return df
|
||
|
||
|
||
#sina新浪全周期获取函数,分钟线 5m,15m,30m,60m 日线1d=240m 周线1w=1200m 1月=7200m
|
||
|
||
def get_price_sina(code, end_date='', count=10, frequency='60m'): #新浪全周期获取函数
|
||
|
||
frequency=frequency.replace('1d','240m').replace('1w','1200m').replace('1M','7200m'); mcount=count
|
||
|
||
ts=int(frequency[:-1]) if frequency[:-1].isdigit() else 1 #解析K线周期数
|
||
|
||
if (end_date!='') & (frequency in ['240m','1200m','7200m']):
|
||
|
||
end_date=pd.to_datetime(end_date) if not isinstance(end_date,datetime.date) else end_date #转换成datetime
|
||
unit=4 if frequency=='1200m' else 29 if frequency=='7200m' else 1 #4,29多几个数据不影响速度
|
||
|
||
count=count+(datetime.datetime.now()-end_date).days//unit #结束时间到今天有多少天自然日(肯定 >交易日)
|
||
|
||
#print(code,end_date,count)
|
||
|
||
URL=f'http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/CN_MarketData.getKLineData?symbol={code}&scale={ts}&ma=5&datalen={count}'
|
||
|
||
dstr= json.loads(requests.get(URL).content);
|
||
|
||
#df=pd.DataFrame(dstr,columns=['day','open','high','low','close','volume'],dtype='float')
|
||
|
||
df= pd.DataFrame(dstr,columns=['day','open','high','low','close','volume'])
|
||
|
||
df['open'] = df['open'].astype(float); df['high'] = df['high'].astype(float); #转换数据类型
|
||
df['low'] = df['low'].astype(float); df['close'] = df['close'].astype(float); df['volume'] = df['volume'].astype(float)
|
||
|
||
df.day=pd.to_datetime(df.day);
|
||
|
||
df.set_index(['day'], inplace=True);
|
||
|
||
df.index.name='' #处理索引
|
||
|
||
if (end_date!='') & (frequency in ['240m','1200m','7200m']):
|
||
return df[df.index<=end_date][-mcount:] #日线带结束时间先返回
|
||
|
||
return df
|
||
|
||
|
||
def get_price(code, end_date='',count=10, frequency='1d', fields=[]): #对外暴露只有唯一函数,这样对用户才是最友好的
|
||
|
||
xcode= code.replace('.XSHG','').replace('.XSHE','') #证券代码编码兼容处理
|
||
xcode='sh'+xcode if ('XSHG' in code) else 'sz'+xcode if ('XSHE' in code) else code
|
||
|
||
if frequency in ['1d','1w','1M']: #1d日线 1w周线 1M月线
|
||
try:
|
||
return get_price_sina( xcode, end_date=end_date,count=count,frequency=frequency) #主力
|
||
except:
|
||
return get_price_day_tx(xcode,end_date=end_date,count=count,frequency=frequency) #备用
|
||
|
||
if frequency in ['1m','5m','15m','30m','60m']: #分钟线 ,1m只有腾讯接口 5分钟5m 60分钟60m
|
||
if frequency in '1m':
|
||
return get_price_min_tx(xcode,end_date=end_date,count=count,frequency=frequency)
|
||
try:
|
||
return get_price_sina(xcode,end_date=end_date,count=count,frequency=frequency) #主力
|
||
except:
|
||
return get_price_min_tx(xcode,end_date=end_date,count=count,frequency=frequency) #备用
|
||
|
||
|
||
if __name__ == "__main__":
|
||
tool = StockInfo()
|
||
output = tool.run(code='sh600519', end_date='', count=10, frequency='15m')
|
||
print(output.json(indent=2)) |