[update] gateway & data

This commit is contained in:
msincenselee 2021-09-04 13:48:21 +08:00
parent 98ab549e0d
commit b149b6cc18
17 changed files with 400 additions and 64 deletions

76
examples/stock/demo_01.py Normal file
View File

@ -0,0 +1,76 @@
# flake8: noqa
# 示例代码
# 从本地股票数据加载,前复权,显示主图指标、副图指标、缠论
import os
import sys
import json
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if vnpy_root not in sys.path:
print(f'sys.path append({vnpy_root})')
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.data.tdx.tdx_common import FakeStrategy
from vnpy.data.tdx.tdx_stock_data import *
from vnpy.component.cta_line_bar import CtaMinuteBar
from vnpy.trader.ui.kline.ui_snapshot import UiSnapshot
from vnpy.trader.ui import create_qapp
from vnpy.data.common import get_stock_bars
if __name__ == "__main__":
# 创建一个假的策略
t1 = FakeStrategy()
# 股票代码.交易所
vt_symbol = '000001.SZSE'
# 数据周期
bar_freq = '15m'
# 一根bar代表的分钟数
bar_interval = int(bar_freq.replace('m', ''))
# 获取某个合约得的分时数据,周期是15分钟返回数据类型是barData
print('加载数据')
bars, msg = get_stock_bars(vt_symbol=vt_symbol, freq=bar_freq)
# 创建一个15分钟bar的 kline对象
setting = {}
setting['name'] = f'{vt_symbol}_{bar_freq}'
setting['bar_interval'] = bar_interval
setting['para_ma1_len'] = 55 # 双均线
setting['para_ma2_len'] = 89
setting['para_macd_fast_len'] = 12 # 激活macd
setting['para_macd_slow_len'] = 26
setting['para_macd_signal_len'] = 9
setting['para_active_chanlun'] = True # 激活缠论
setting['price_tick'] = 1
setting['is_stock'] = True
setting['underly_symbol'] = vt_symbol.split('.')[0]
kline = CtaMinuteBar(strategy=t1, cb_on_bar=None, setting=setting)
# 推送bar到kline中
for bar in bars:
kline.add_bar(bar, bar_is_completed=True, bar_freq=bar_interval)
# 获取kline的切片数据
data = kline.get_data()
snapshot = {
'strategy': "demo",
'datetime': datetime.now(),
"kline_names": [kline.name],
"klines": {kline.name: data}}
# 创建一个GUI界面应用app
qApp = create_qapp()
# 创建切片回放工具窗口
ui = UiSnapshot()
# 显示切片内容
ui.show(snapshot_file="",
d=snapshot)
sys.exit(qApp.exec_())

View File

@ -147,18 +147,15 @@ def refill(symbol_info):
# thread_tasks.append(task) # thread_tasks.append(task)
def resample(symbol, exchange, x_mins=[5, 15, 30]): def resample(vt_symbol, x_mins=[5, 15, 30]):
""" """
更新多周期文件 更新多周期文件
:param symbol: :param vt_symbol: 代码.交易所
:param exchange:
:param x_mins: :param x_mins:
:return: :return:
""" """
d1 = datetime.now() d1 = datetime.now()
out_files, err_msg = resample_bars_file(vnpy_root=vnpy_root, out_files, err_msg = resample_bars_file(vt_symbol=vt_symbol,
symbol=symbol,
exchange=exchange,
x_mins=x_mins) x_mins=x_mins)
d2 = datetime.now() d2 = datetime.now()
microseconds = round((d2 - d1).microseconds / 100, 0) microseconds = round((d2 - d1).microseconds / 100, 0)

View File

@ -5,7 +5,7 @@ import pika
import random import random
import traceback import traceback
from vnpy.amqp.base import base_broker from vnpy.amqp.base import base_broker
from vnpy.component.base import MyEncoder
# 模式1接收者 # 模式1接收者
class receiver(base_broker): class receiver(base_broker):
@ -307,7 +307,7 @@ class rpc_server(base_broker):
def reply(self, chan, reply_data, reply_to, reply_id, delivery_tag): def reply(self, chan, reply_data, reply_to, reply_id, delivery_tag):
"""返回调用结果""" """返回调用结果"""
# data => string # data => string
reply_msg = json.dumps(reply_data) reply_msg = json.dumps(reply_data,cls=MyEncoder)
# 发送返回消息 # 发送返回消息
chan.basic_publish(exchange=self.exchange, chan.basic_publish(exchange=self.exchange,
routing_key=reply_to, routing_key=reply_to,

View File

@ -114,7 +114,7 @@ class RemoteClient:
# 整个接口对外保持和原来的一致 # 整个接口对外保持和原来的一致
# 通过对原requests接口的“鸭子类型替换”来实现透明化 # 通过对原requests接口的“鸭子类型替换”来实现透明化
def use(broker, host, port=1430, use_zmq=True, **kwargs): def use(broker, host, port=1430, use_zmq=False, **kwargs):
if use_zmq: if use_zmq:
return ZMQRemoteClient(broker, host, port) return ZMQRemoteClient(broker, host, port)
else: else:

View File

@ -140,6 +140,7 @@ class RestClient(object):
self.logger: Optional[logging.Logger] = None self.logger: Optional[logging.Logger] = None
self.proxies = None self.proxies = None
self.cookies = {}
self.thread_executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 20) self.thread_executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 20)
@ -436,8 +437,13 @@ class RestClient(object):
if status_code == 204: if status_code == 204:
json_body = None json_body = None
else: else:
try:
json_body = response.json() json_body = response.json()
except Exception as ex:
json_body = response.content.decode('utf-8')
self._process_json_body(json_body, request) self._process_json_body(json_body, request)
if response.cookies.get_dict():
self.cookies.update(response.cookies)
else: else:
if request.on_failed: if request.on_failed:
request.status = RequestStatus.failed request.status = RequestStatus.failed
@ -462,7 +468,7 @@ class RestClient(object):
else: else:
self.on_error(t, v, tb, request) self.on_error(t, v, tb, request)
def _process_json_body(self, json_body: Optional[dict], request: "Request"): def _process_json_body(self, json_body: Union[dict,str], request: "Request"):
status_code = request.response.status_code status_code = request.response.status_code
if self.is_request_success(json_body, request): if self.is_request_success(json_body, request):
request.status = RequestStatus.success request.status = RequestStatus.success

View File

@ -243,7 +243,7 @@ class AccountRecorder(BaseEngine):
end_day = dt_now.strftime('%Y%m%d') end_day = dt_now.strftime('%Y%m%d')
gw = self.main_engine.get_gateway(gw_name) gw = self.main_engine.get_gateway(gw_name)
if gw is None: if gw is None:
continue self.write_log(f'Account_recorder找不到{gw_name}')
if hasattr(gw, 'qryHistory'): if hasattr(gw, 'qryHistory'):
self.write_log(u'{}请求{}数据,{}~{}'.format(gw_name, data_type, begin_day, end_day)) self.write_log(u'{}请求{}数据,{}~{}'.format(gw_name, data_type, begin_day, end_day))
gw.qryHistory(data_type, begin_day, end_day) gw.qryHistory(data_type, begin_day, end_day)

View File

@ -20,7 +20,7 @@ from .base import StopOrder
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_position import CtaPosition
from vnpy.component.cta_policy import CtaPolicy from vnpy.component.cta_policy import CtaPolicy
from vnpy.component.base import MyEncoder
class CtaTemplate(ABC): class CtaTemplate(ABC):
"""CTA策略模板""" """CTA策略模板"""
@ -1368,7 +1368,7 @@ class CtaFutureTemplate(CtaTemplate):
if policy: if policy:
op = getattr(policy, 'to_json', None) op = getattr(policy, 'to_json', None)
if callable(op): if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder)))
def save_dist(self, dist_data): def save_dist(self, dist_data):
""" """
@ -2152,7 +2152,7 @@ class CtaSpotTemplate(CtaTemplate):
if policy: if policy:
op = getattr(policy, 'to_json', None) op = getattr(policy, 'to_json', None)
if callable(op): if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder)))
def save_dist(self, dist_data): def save_dist(self, dist_data):
""" """

View File

@ -20,6 +20,7 @@ from .base import StopOrder,EngineType
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_position import CtaPosition
from vnpy.component.cta_policy import CtaPolicy from vnpy.component.cta_policy import CtaPolicy
from vnpy.component.base import MyEncoder
class CtaTemplate(ABC): class CtaTemplate(ABC):
"""CTA股票策略模板""" """CTA股票策略模板"""
@ -602,7 +603,7 @@ class CtaStockTemplate(CtaTemplate):
"""初始化Policy""" """初始化Policy"""
self.write_log(u'init_policy(),初始化执行逻辑') self.write_log(u'init_policy(),初始化执行逻辑')
self.policy.load() self.policy.load()
self.write_log('{}'.format(json.dumps(self.policy.to_json(),indent=2, ensure_ascii=False))) self.write_log('{}'.format(json.dumps(self.policy.to_json(),indent=2, ensure_ascii=False,cls=MyEncoder)))
def init_position(self): def init_position(self):
""" """
@ -1075,9 +1076,12 @@ class CtaStockTemplate(CtaTemplate):
continue continue
# 实盘运行时,要加入市场买卖量的判断 # 实盘运行时,要加入市场买卖量的判断
limit_down = None
if not force and not self.backtesting: if not force and not self.backtesting:
symbol_tick = self.cta_engine.get_tick(vt_symbol) symbol_tick = self.cta_engine.get_tick(vt_symbol)
if symbol_tick: if symbol_tick:
if symbol_tick.limit_down > 0:
limit_down = symbol_tick.limit_down
symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol)
# 根据市场计算前5档买单数量 # 根据市场计算前5档买单数量
if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3,
@ -1095,7 +1099,11 @@ class CtaStockTemplate(CtaTemplate):
self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume)) self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume))
# 获取当前价格 # 获取当前价格
if limit_down is None or cur_price > limit_down:
sell_price = cur_price - self.cta_engine.get_price_tick(vt_symbol) sell_price = cur_price - self.cta_engine.get_price_tick(vt_symbol)
else:
sell_price = cur_price
# 发出委托卖出 # 发出委托卖出
vt_orderids = self.sell( vt_orderids = self.sell(
vt_symbol=vt_symbol, vt_symbol=vt_symbol,
@ -1134,7 +1142,7 @@ class CtaStockTemplate(CtaTemplate):
dist_record = dict() dist_record = dict()
dist_record['volume'] = grid.volume dist_record['volume'] = grid.volume
dist_record['price'] = self.cta_engine.get_price(grid.vt_symbol) dist_record['price'] = self.cta_engine.get_price(grid.vt_symbol)
dist_record['operation'] = 'execute finished' dist_record['operation'] = 'sell finished'
dist_record['signal'] = grid.type dist_record['signal'] = grid.type
self.save_dist(dist_record) self.save_dist(dist_record)
@ -1322,6 +1330,11 @@ class CtaStockTemplate(CtaTemplate):
elif order_status == Status.CANCELLED: elif order_status == Status.CANCELLED:
self.write_log(u'委托单{}已成功撤单,删除{}'.format(vt_orderid, order_info)) self.write_log(u'委托单{}已成功撤单,删除{}'.format(vt_orderid, order_info))
canceled_ids.append(vt_orderid) canceled_ids.append(vt_orderid)
elif order_status == Status.CANCELLING:
if over_seconds > self.cancel_seconds * 3:
self.write_log(u'委托单{}正在撤单,超时{},删除{}'.format(vt_orderid,over_seconds, order_info))
canceled_ids.append(vt_orderid)
# 删除撤单的订单 # 删除撤单的订单
for vt_orderid in canceled_ids: for vt_orderid in canceled_ids:
@ -1381,7 +1394,7 @@ class CtaStockTemplate(CtaTemplate):
policy = getattr(self, 'policy') policy = getattr(self, 'policy')
op = getattr(policy, 'to_json', None) op = getattr(policy, 'to_json', None)
if callable(op): if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder)))
def save_dist(self, dist_data): def save_dist(self, dist_data):
""" """

View File

@ -76,7 +76,7 @@ from .base import (
STOPORDER_PREFIX, STOPORDER_PREFIX,
) )
from .template import CtaTemplate from .template import CtaTemplate
from vnpy.component.base import MARKET_DAY_ONLY from vnpy.component.base import MARKET_DAY_ONLY, MyEncoder
from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_position import CtaPosition
STOP_STATUS_MAP = { STOP_STATUS_MAP = {

View File

@ -165,6 +165,7 @@ class CtaLineBar(object):
self.price_tick = 1 # 商品的最小价格单位 self.price_tick = 1 # 商品的最小价格单位
self.round_n = 4 # round() 小数点的截断数量 self.round_n = 4 # round() 小数点的截断数量
self.is_7x24 = False # 是否7x24小时运行 一般为数字货币) self.is_7x24 = False # 是否7x24小时运行 一般为数字货币)
self.is_stock = False # 是否为股票
# 当前的Tick的信息 # 当前的Tick的信息
self.cur_tick = None # 当前 onTick()函数接收的 最新的tick self.cur_tick = None # 当前 onTick()函数接收的 最新的tick
@ -230,6 +231,8 @@ class CtaLineBar(object):
self.minute_interval = None # 把各个周期的bar转换为分钟在first_tick中用来修正bar为整点分钟周期 self.minute_interval = None # 把各个周期的bar转换为分钟在first_tick中用来修正bar为整点分钟周期
if setting: if setting:
self.set_params(setting) self.set_params(setting)
if self.is_stock:
self.is_7x24 = True
# 修正self.minute_interval # 修正self.minute_interval
if self.interval == Interval.SECOND: if self.interval == Interval.SECOND:
@ -283,6 +286,7 @@ class CtaLineBar(object):
self.param_list.append('interval') # bar的类型 self.param_list.append('interval') # bar的类型
self.param_list.append('mode') # tick/bar模式 self.param_list.append('mode') # tick/bar模式
self.param_list.append('is_7x24') # 是否为7X24小时运行的bar一般为数字货币) self.param_list.append('is_7x24') # 是否为7X24小时运行的bar一般为数字货币)
self.param_list.append('is_stock') # 是否为7X24小时运行的bar一般为数字货币)
self.param_list.append('price_tick') # 最小跳动,用于处理指数等不一致的价格 self.param_list.append('price_tick') # 最小跳动,用于处理指数等不一致的价格
self.param_list.append('underly_symbol') # 短合约, self.param_list.append('underly_symbol') # 短合约,

View File

@ -1,20 +1,196 @@
import os import os
import pandas as pd import pandas as pd
import numpy as np
from typing import Union, List
from datetime import datetime
# 所有股票的复权因子
STOCK_ADJUST_FACTORS = {}
def get_bardata_folder(data_folder: str) -> str:
"""
如果data_folder为空白就返回bar_data的目录
:param data_folder:
:return:
"""
if len(data_folder) == 0 or not os.path.exists(data_folder):
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))
return data_folder
def get_stock_bars(vt_symbol:str,
freq: str = "1d",
start_date: str = "",
fq_type:str ="qfq") -> (List, str):
"""
获取本地文件的股票bar数据
:param vt_symbol:
:param freq:
:param start_date: 20180101 或者 2018-01-01
:param fq_type: qfq:前复权hfq:后复权; 空白:不复权
:return:
"""
# 获取未复权的bar dataframe数据
df, err_msg = get_stock_raw_data(vt_symbol=vt_symbol, freq=freq, start_date=start_date)
bars = []
if len(err_msg) > 0 or df is None:
return bars, err_msg
if fq_type != "":
from vnpy.data.stock.adjust_factor import get_all_adjust_factor
STOCK_ADJUST_FACTORS = get_all_adjust_factor()
adj_list = STOCK_ADJUST_FACTORS.get(vt_symbol, [])
if len(adj_list) > 0:
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:30:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
df = stock_to_adj(df, adj_data, adj_type='fore' if fq_type == 'qfw' else 'back')
from vnpy.trader.object import BarData
from vnpy.trader.constant import Exchange
symbol, exchange = vt_symbol.split('.')
for dt, bar_data in df.iterrows():
bar_datetime = dt # - timedelta(seconds=bar_interval_seconds)
bar = BarData(
gateway_name='backtesting',
symbol=symbol,
exchange=Exchange(exchange),
datetime=bar_datetime
)
if 'open' in bar_data:
bar.open_price = float(bar_data['open'])
bar.close_price = float(bar_data['close'])
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
else:
bar.open_price = float(bar_data['open_price'])
bar.close_price = float(bar_data['close_price'])
bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0
bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
else:
bar.trading_day = bar.date
bars.append(bar)
return bars, ""
def get_stock_raw_data(vt_symbol: str,
freq: str = "1d",
start_date: str = "",
bar_data_folder: str = "") -> (Union[pd.DataFrame, None], str):
"""
获取本地bar_data下的 交易所/股票代码_时间周期.csv原始bar数据未复权
:param vt_symbol: 600001.SSE 600001
:param freq: 1m,5m, 15m, 30m, 1h, 1d
:param start_date: 开始日期
:param bar_data_folder: 强制指定bar_data所在目录
:return: DataFrame, err_msg
"""
symbol, exchange = vt_symbol.split('.')
# 1分钟 csv文件路径
csv_file = os.path.abspath(os.path.join(
get_bardata_folder(bar_data_folder),
exchange,
f'{symbol}_{freq}.csv'))
if not os.path.exists(csv_file):
err_msg = f'{csv_file} 文件不存在,不能读取'
return None, err_msg
try:
# 载入原始csv => dataframe
df = pd.read_csv(csv_file)
datetime_format = "%Y-%m-%d %H:%M:%S"
# 转换时间str =》 datetime
df["datetime"] = pd.to_datetime(df["datetime"], format=datetime_format)
# 使用'datetime'字段作为索引
df.set_index("datetime", inplace=True)
if len(start_date) > 0:
if len(start_date) == 8:
_format = '%Y%m%d'
else:
_format = '%Y-%m-%d'
start_date = datetime.strptime(start_date, _format)
df = df.loc[start_date:]
return df, ""
except Exception as ex:
err_msg = f'读取异常:{str(ex)}'
return None, err_msg
def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False): def stock_to_adj(raw_data: pd.DataFrame,
adj_data: pd.DataFrame,
adj_type: str) -> pd.DataFrame:
"""
股票数据复权转换
:param raw_data: 不复权数据
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df
:param adj_type: 复权类型, fore, 前复权 back,后复权
:return:
"""
if adj_type == 'fore':
adj_factor = adj_data["foreAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1
else:
adj_factor = adj_data["backAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1
# 把raw_data的第一个日期插入复权因子df使用后填充
if adj_factor.index[0] != raw_data.index[0]:
adj_factor.loc[raw_data.index[0]] = np.nan
adj_factor.sort_index(inplace=True)
adj_factor = adj_factor.ffill()
adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引
adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格
# 把复权因子作为adj字段补充到raw_data中
raw_data['adj'] = adj_factor
# 逐一复权高低开平和成交量
for col in ['open', 'high', 'low', 'close']:
raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数
raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数
return raw_data
def resample_bars_file(vt_symbol: str,
x_mins: List[str] = [],
include_day: bool = False,
bar_data_folder: str = "") -> (list, str):
""" """
重建x分钟K线和日线csv文件 重建x分钟K线和日线csv文件
:param symbol: :param vt_symbol: 代码.交易所
:param x_mins: [5, 15, 30, 60] :param x_mins: [5, 15, 30, 60]
:param include_day: 是否也重建日线 :param include_day: 是否也重建日线
:param vnpy_root: 项目所在根目录
:return: out_files,err_msg :return: out_files,err_msg
""" """
err_msg = "" err_msg = ""
out_files = [] out_files = []
symbol, exchange = vt_symbol.split('.')
# 1分钟 csv文件路径 # 1分钟 csv文件路径
csv_file = os.path.abspath(os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_1m.csv')) csv_file = os.path.abspath(os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_1m.csv'))
if not os.path.exists(csv_file): if not os.path.exists(csv_file):
err_msg = f'{csv_file} 文件不存在,不能转换' err_msg = f'{csv_file} 文件不存在,不能转换'
@ -49,10 +225,13 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False
for x_min in x_mins: for x_min in x_mins:
# 目标文件 # 目标文件
target_file = os.path.abspath( target_file = os.path.abspath(
os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_{x_min}m.csv')) os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_{x_min}m.csv'))
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据 # 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
#df_target = df_1m.resample(f'{x_min}min', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any') # df_target = df_1m.resample(f'{x_min}min', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any')
df_target = df_1m.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, df_target = df_1m.resample(
f'{x_min}min',
closed='left',
label='left').agg(ohlc_rule).dropna(axis=0,
how='any') how='any')
# dropna(axis=0, how='any') axis参数0针对行进行操作 1针对列进行操作 how参数any只要包含就删除 all全是为NaN才删除 # dropna(axis=0, how='any') axis参数0针对行进行操作 1针对列进行操作 how参数any只要包含就删除 all全是为NaN才删除
@ -64,10 +243,13 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False
if include_day: if include_day:
# 目标文件 # 目标文件
target_file = os.path.abspath( target_file = os.path.abspath(
os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_1d.csv')) os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_1d.csv'))
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据 # 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
# df_target = df_1m.resample(f'D', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any') # df_target = df_1m.resample(f'D', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any')
df_target = df_1m.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any') df_target = df_1m.resample(
f'D',
closed='left',
label='left').agg(ohlc_rule).dropna(axis=0, how='any')
# dropna(axis=0, how='any') axis参数0针对行进行操作 1针对列进行操作 how参数any只要包含就删除 all全是为NaN才删除 # dropna(axis=0, how='any') axis参数0针对行进行操作 1针对列进行操作 how参数any只要包含就删除 all全是为NaN才删除
if len(df_target) > 0: if len(df_target) > 0:
@ -75,4 +257,4 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False
print(f'生成[日线] => {target_file}') print(f'生成[日线] => {target_file}')
out_files.append(target_file) out_files.append(target_file)
return out_files,err_msg return out_files, err_msg

View File

@ -130,4 +130,21 @@ def download_adjust_factor():
return factor_dict return factor_dict
if __name__ == '__main__': if __name__ == '__main__':
download_adjust_factor()
# 下载所有复权数据
# download_adjust_factor()
# 下载某个股票的复权数据
# f = get_adjust_factor(vt_symbol='000651.SZSE',stock_name='格力电器',need_login=True)
#
# for d in f:
# print(d)
# 读取缓存文件中某只股票的复权数据
factors = get_all_adjust_factor()
f = factors.get('000651.SZSE',None)
if f is None:
print('获取不到数据')
else:
for d in f:
print(d)

View File

@ -34,7 +34,7 @@ TDX_PROXY_CONFIG = 'tdx_proxy_config.json'
def get_tdx_market_code(code): def get_tdx_market_code(code):
# 获取通达信股票的market code # 获取通达信股票的market code
code = str(code) code = str(code)
if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]: if code[0] in ['5', '6', '9'] or code[:3] in ["880","009", "126", "110", "201", "202", "203", "204"]:
# 上海证券交易所 # 上海证券交易所
return 1 return 1
# 深圳证券交易所 # 深圳证券交易所
@ -101,10 +101,13 @@ def get_cache_config(config_file_name):
config = {} config = {}
if not os.path.exists(config_file_name): if not os.path.exists(config_file_name):
return config return config
try:
with bz2.BZ2File(config_file_name, 'rb') as f: with bz2.BZ2File(config_file_name, 'rb') as f:
config = pickle.load(f) config = pickle.load(f)
return config return config
except Exception as ex:
print(f'读取缓存本地文件:{config_file_name}异常{str(ex)}')
return config
def save_cache_config(data: dict, config_file_name): def save_cache_config(data: dict, config_file_name):
"""保存本地缓存的配置地址信息""" """保存本地缓存的配置地址信息"""
@ -126,8 +129,9 @@ def save_cache_json(data_dict: dict, json_file_name: str):
save_json(filename=config_file_name, data=data_dict) save_json(filename=config_file_name, data=data_dict)
def get_stock_type(code): def get_stock_type(code,market_id = None ):
"""获取股票得分类""" """获取股票得分类"""
if market_id is None:
market_id = get_tdx_market_code(code) market_id = get_tdx_market_code(code)
if market_id == 0: if market_id == 0:

View File

@ -314,13 +314,18 @@ class TdxStockData(object):
self.write_log('{}开始下载tdx股票: {},代码:{} {}数据, {} to {}.' self.write_log('{}开始下载tdx股票: {},代码:{} {}数据, {} to {}.'
.format(datetime.now(), name, tdx_code, tdx_period, qry_start_date, qry_end_date)) .format(datetime.now(), name, tdx_code, tdx_period, qry_start_date, qry_end_date))
stock_type = get_stock_type(tdx_code,market_id)
if stock_type == 'index_cn':
get_bar_func = self.api.get_index_bars
else:
get_bar_func = self.api.get_security_bars
try: try:
_start_date = qry_end_date _start_date = qry_end_date
_bars = [] _bars = []
_pos = 0 _pos = 0
while _start_date > qry_start_date: while _start_date > qry_start_date:
_res = self.api.get_security_bars( _res = get_bar_func(
category=PERIOD_MAPPING[period], category=PERIOD_MAPPING[period],
market=market_id, market=market_id,
code=tdx_code, code=tdx_code,
@ -452,8 +457,13 @@ class TdxStockData(object):
.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) .format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
return False, ret_bars return False, ret_bars
tdx_period = PERIOD_MAPPING.get(period) tdx_period = PERIOD_MAPPING.get(period)
stock_type = get_stock_type(tdx_code)
if stock_type == 'index_cn':
get_bar_func = self.api.get_index_bars
else:
get_bar_func = self.api.get_security_bars
try: try:
datas = self.api.get_security_bars( datas = get_bar_func(
category=PERIOD_MAPPING[period], category=PERIOD_MAPPING[period],
market=market_id, market=market_id,
code=tdx_code, code=tdx_code,
@ -490,6 +500,8 @@ class TdxStockData(object):
self.write_error(f'获取{symbol}数据失败:{str(ex)}') self.write_error(f'获取{symbol}数据失败:{str(ex)}')
return False, ret_bars return False, ret_bars
# ----------------------------------------------------------------------
def save_cache(self, def save_cache(self,
cache_folder: str, cache_folder: str,
cache_symbol: str, cache_symbol: str,

View File

@ -19,26 +19,36 @@ t1 = FakeStrategy()
t2 = FakeStrategy() t2 = FakeStrategy()
# 创建API对象(使用本地socket5代理 # 创建API对象(使用本地socket5代理
api_01 = TdxStockData(strategy=t1, proxy_ip='localhost', proxy_port=1080) #api_01 = TdxStockData(strategy=t1, proxy_ip='localhost', proxy_port=1080)
# 不使用代理 # 不使用代理
#api_01 = TdxStockData(strategy=t1) api_01 = TdxStockData(strategy=t1)
#
# 获取市场下股票 # # 获取市场下股票
for market_id in range(2): # for market_id in range(2):
print('get market_id:{}'.format(market_id)) # print('get market_id:{}'.format(market_id))
security_list = api_01.get_security_list(market_id) # security_list = api_01.get_security_list(market_id)
if len(security_list) == 0: # if len(security_list) == 0:
continue # continue
for security in security_list: # for security in security_list:
if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''): # if security['code'].startswith('88'):
str_security = json.dumps(security, indent=1, ensure_ascii=False) # str_security = json.dumps(security, indent=1, ensure_ascii=False)
print('market_id:{},{}'.format(market_id, str_security)) # print(str_security)
# if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''):
# str_security = json.dumps(security, indent=1, ensure_ascii=False)
# # print('market_id:{},{}'.format(market_id, str_security))
# str_markets = json.dumps(security_list, indent=1, ensure_ascii=False) # str_markets = json.dumps(security_list, indent=1, ensure_ascii=False)
# print(u'{}'.format(str_markets)) # print(u'{}'.format(str_markets))
# 获取历史分钟线 # 获取历史分钟线
# api_01.get_bars('002024', period='1hour', callback=t1.display_bar) ret,result = api_01.get_bars('880351.SSE', period='1hour', callback=t1.display_bar)
if ret:
for bar in result:
print(bar)
# ret,result = api_01.get_last_bars(symbol='002024',return_bar=True)
# if ret:
# print(result)
# api.get_bars(symbol, period='5min', callback=display_bar) # api.get_bars(symbol, period='5min', callback=display_bar)
# api.get_bars(symbol, period='1day', callback=display_bar) # api.get_bars(symbol, period='1day', callback=display_bar)
@ -50,7 +60,7 @@ for market_id in range(2):
# for r in result[0:10] + result[-10:]: # for r in result[0:10] + result[-10:]:
# print(r) # print(r)
# 获取历史分时数据 # # 获取历史分时数据
ret, result = api_01.get_history_transaction_data('110031', '20200504') # ret, result = api_01.get_history_transaction_data('110031', '20200504')
for r in result[0:10] + result[-10:]: # for r in result[0:10] + result[-10:]:
print(r) # print(r)

View File

@ -940,11 +940,21 @@ class ThsTdApi(object):
return return
if '总资产' not in data: if '总资产' not in data:
return return
## 为了兼容东财的webapi这里frozen做个特殊处理
# if "冻结金额" in data:
# # 同花顺直接给了冻结金额
# frozen = float(data["冻结金额"])
# else:
# # 东财没有冻结金额这个项目,要计算
# frozen = float(data["总资产"]) - float(data["资金余额"])
frozen = 0
account = AccountData( account = AccountData(
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
accountid=self.userid, accountid=self.userid,
balance=float(data["总资产"]), balance=float(data["总资产"]),
frozen=float(data["总资产"]) - float(data["资金余额"]), frozen=frozen,
currency="人民币", currency="人民币",
trading_day=self.trading_day trading_day=self.trading_day
) )

View File

@ -578,16 +578,20 @@ class IndexGenerator:
self.exchange = setting.get('exchange', None) self.exchange = setting.get('exchange', None)
self.price_tick = setting.get('price_tick') self.price_tick = setting.get('price_tick')
self.symbols = setting.get('symbols', {}) self.symbols = setting.get('symbols', {})
self.pre_oi_total = 1
# 订阅行情 # 订阅行情
self.subscribe() self.subscribe()
self.n = len(self.symbols) self.n = len(self.symbols)
def subscribe(self): def subscribe(self):
"""订阅行情""" """订阅行情"""
dt_now = datetime.now() dt_now = datetime.now()
for symbol in list(self.symbols.keys()): for symbol in list(self.symbols.keys()):
pre_open_interest = self.symbols.get(symbol,0) pre_open_interest = self.symbols.get(symbol,0)
self.pre_oi_total += pre_open_interest
# 全路径合约 => 标准合约 ,如 ZC2109 => ZC109, RB2110 => rb2110 # 全路径合约 => 标准合约 ,如 ZC2109 => ZC109, RB2110 => rb2110
vn_symbol = get_real_symbol_by_exchange(symbol, Exchange(self.exchange)) vn_symbol = get_real_symbol_by_exchange(symbol, Exchange(self.exchange))
# 先移除 # 先移除
@ -596,6 +600,9 @@ class IndexGenerator:
self.gateway.write_log(f'移除早于当月的合约{symbol}') self.gateway.write_log(f'移除早于当月的合约{symbol}')
continue continue
if pre_open_interest < 100:
self.gateway.write_log(f'移除持仓量:{pre_open_interest}低于100的合约{symbol}')
continue
# 重新登记合约 # 重新登记合约
self.symbols[vn_symbol] = pre_open_interest self.symbols[vn_symbol] = pre_open_interest
@ -625,12 +632,6 @@ class IndexGenerator:
bid_price_1 = 0 bid_price_1 = 0
mi_tick = None mi_tick = None
# 已经积累的行情tick数量不足总数减1不处理
if len(self.ticks) < min(self.n * 0.8, 3):
self.gateway.write_log(f'{self.underlying_symbol}合约数据{len(self.ticks)}不足{self.n} 0.8,暂不合成指数')
return
# 计算所有合约的累加持仓量、资金、成交量、找出最大持仓量的主力合约 # 计算所有合约的累加持仓量、资金、成交量、找出最大持仓量的主力合约
for t in self.ticks.values(): for t in self.ticks.values():
all_interest += t.open_interest all_interest += t.open_interest
@ -641,6 +642,10 @@ class IndexGenerator:
if mi_tick is None or mi_tick.open_interest < t.open_interest: if mi_tick is None or mi_tick.open_interest < t.open_interest:
mi_tick = t mi_tick = t
if not (len(self.ticks) > min(self.n * 0.7, 3) or all_interest > self.pre_oi_total * 0.5):
self.gateway.write_log(f'{self.underlying_symbol}合约数据{len(self.ticks)}不足{self.n} 0.7或者累计持仓数不够昨持仓0.5,暂不合成指数')
return
# 总量 > 0 # 总量 > 0
if all_interest > 0 and all_amount > 0: if all_interest > 0 and all_amount > 0:
last_price = round(float(all_amount / all_interest), 4) last_price = round(float(all_amount / all_interest), 4)