From b149b6cc18ecb2220de06e144badc14b77ad4b2d Mon Sep 17 00:00:00 2001 From: msincenselee Date: Sat, 4 Sep 2021 13:48:21 +0800 Subject: [PATCH] [update] gateway & data --- examples/stock/demo_01.py | 76 +++++++++++ prod/jobs/refill_tdx_stock_bars.py | 9 +- vnpy/amqp/consumer.py | 4 +- vnpy/api/easytrader/remoteclient.py | 2 +- vnpy/api/rest/rest_client.py | 10 +- vnpy/app/account_recorder/engine.py | 2 +- vnpy/app/cta_crypto/template.py | 6 +- vnpy/app/cta_stock/template.py | 21 ++- vnpy/app/cta_strategy_pro/engine.py | 2 +- vnpy/component/cta_line_bar.py | 4 + vnpy/data/common.py | 202 ++++++++++++++++++++++++++-- vnpy/data/stock/adjust_factor.py | 19 ++- vnpy/data/tdx/tdx_common.py | 16 ++- vnpy/data/tdx/tdx_stock_data.py | 16 ++- vnpy/data/tdx/test_tdx_stock.py | 46 ++++--- vnpy/gateway/ths/ths_gateway.py | 12 +- vnpy/trader/gateway.py | 17 ++- 17 files changed, 400 insertions(+), 64 deletions(-) create mode 100644 examples/stock/demo_01.py diff --git a/examples/stock/demo_01.py b/examples/stock/demo_01.py new file mode 100644 index 00000000..c2a32f43 --- /dev/null +++ b/examples/stock/demo_01.py @@ -0,0 +1,76 @@ +# flake8: noqa + +# 示例代码 +# 从本地股票数据加载,前复权,显示主图指标、副图指标、缠论 + +import os +import sys +import json + +vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) +if vnpy_root not in sys.path: + print(f'sys.path append({vnpy_root})') + sys.path.append(vnpy_root) + +os.environ["VNPY_TESTING"] = "1" + +from vnpy.data.tdx.tdx_common import FakeStrategy +from vnpy.data.tdx.tdx_stock_data import * +from vnpy.component.cta_line_bar import CtaMinuteBar +from vnpy.trader.ui.kline.ui_snapshot import UiSnapshot +from vnpy.trader.ui import create_qapp +from vnpy.data.common import get_stock_bars + +if __name__ == "__main__": + + # 创建一个假的策略 + t1 = FakeStrategy() + + # 股票代码.交易所 + vt_symbol = '000001.SZSE' + # 数据周期 + bar_freq = '15m' + # 一根bar代表的分钟数 + bar_interval = int(bar_freq.replace('m', '')) + + # 获取某个合约得的分时数据,周期是15分钟,返回数据类型是barData + print('加载数据') + bars, msg = get_stock_bars(vt_symbol=vt_symbol, freq=bar_freq) + + # 创建一个15分钟bar的 kline对象 + setting = {} + setting['name'] = f'{vt_symbol}_{bar_freq}' + setting['bar_interval'] = bar_interval + setting['para_ma1_len'] = 55 # 双均线 + setting['para_ma2_len'] = 89 + setting['para_macd_fast_len'] = 12 # 激活macd + setting['para_macd_slow_len'] = 26 + setting['para_macd_signal_len'] = 9 + setting['para_active_chanlun'] = True # 激活缠论 + setting['price_tick'] = 1 + setting['is_stock'] = True + setting['underly_symbol'] = vt_symbol.split('.')[0] + kline = CtaMinuteBar(strategy=t1, cb_on_bar=None, setting=setting) + + # 推送bar到kline中 + for bar in bars: + kline.add_bar(bar, bar_is_completed=True, bar_freq=bar_interval) + + # 获取kline的切片数据 + data = kline.get_data() + snapshot = { + 'strategy': "demo", + 'datetime': datetime.now(), + "kline_names": [kline.name], + "klines": {kline.name: data}} + + # 创建一个GUI界面应用app + qApp = create_qapp() + + # 创建切片回放工具窗口 + ui = UiSnapshot() + # 显示切片内容 + ui.show(snapshot_file="", + d=snapshot) + + sys.exit(qApp.exec_()) diff --git a/prod/jobs/refill_tdx_stock_bars.py b/prod/jobs/refill_tdx_stock_bars.py index 5d4be31c..cf9ad920 100644 --- a/prod/jobs/refill_tdx_stock_bars.py +++ b/prod/jobs/refill_tdx_stock_bars.py @@ -147,18 +147,15 @@ def refill(symbol_info): # thread_tasks.append(task) -def resample(symbol, exchange, x_mins=[5, 15, 30]): +def resample(vt_symbol, x_mins=[5, 15, 30]): """ 更新多周期文件 - :param symbol: - :param exchange: + :param vt_symbol: 代码.交易所 :param x_mins: :return: """ d1 = datetime.now() - out_files, err_msg = resample_bars_file(vnpy_root=vnpy_root, - symbol=symbol, - exchange=exchange, + out_files, err_msg = resample_bars_file(vt_symbol=vt_symbol, x_mins=x_mins) d2 = datetime.now() microseconds = round((d2 - d1).microseconds / 100, 0) diff --git a/vnpy/amqp/consumer.py b/vnpy/amqp/consumer.py index 887d4913..c5e79a77 100644 --- a/vnpy/amqp/consumer.py +++ b/vnpy/amqp/consumer.py @@ -5,7 +5,7 @@ import pika import random import traceback from vnpy.amqp.base import base_broker - +from vnpy.component.base import MyEncoder # 模式1:接收者 class receiver(base_broker): @@ -307,7 +307,7 @@ class rpc_server(base_broker): def reply(self, chan, reply_data, reply_to, reply_id, delivery_tag): """返回调用结果""" # data => string - reply_msg = json.dumps(reply_data) + reply_msg = json.dumps(reply_data,cls=MyEncoder) # 发送返回消息 chan.basic_publish(exchange=self.exchange, routing_key=reply_to, diff --git a/vnpy/api/easytrader/remoteclient.py b/vnpy/api/easytrader/remoteclient.py index 93a3c7f8..276c63dd 100644 --- a/vnpy/api/easytrader/remoteclient.py +++ b/vnpy/api/easytrader/remoteclient.py @@ -114,7 +114,7 @@ class RemoteClient: # 整个接口对外保持和原来的一致 # 通过对原requests接口的“鸭子类型替换”来实现透明化 -def use(broker, host, port=1430, use_zmq=True, **kwargs): +def use(broker, host, port=1430, use_zmq=False, **kwargs): if use_zmq: return ZMQRemoteClient(broker, host, port) else: diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index 6a70b0f5..5f45ac70 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -140,6 +140,7 @@ class RestClient(object): self.logger: Optional[logging.Logger] = None self.proxies = None + self.cookies = {} self.thread_executor = ThreadPoolExecutor(max_workers=os.cpu_count() * 20) @@ -436,8 +437,13 @@ class RestClient(object): if status_code == 204: json_body = None else: - json_body = response.json() + try: + json_body = response.json() + except Exception as ex: + json_body = response.content.decode('utf-8') self._process_json_body(json_body, request) + if response.cookies.get_dict(): + self.cookies.update(response.cookies) else: if request.on_failed: request.status = RequestStatus.failed @@ -462,7 +468,7 @@ class RestClient(object): else: self.on_error(t, v, tb, request) - def _process_json_body(self, json_body: Optional[dict], request: "Request"): + def _process_json_body(self, json_body: Union[dict,str], request: "Request"): status_code = request.response.status_code if self.is_request_success(json_body, request): request.status = RequestStatus.success diff --git a/vnpy/app/account_recorder/engine.py b/vnpy/app/account_recorder/engine.py index 4ce13f8b..b77cf5e4 100644 --- a/vnpy/app/account_recorder/engine.py +++ b/vnpy/app/account_recorder/engine.py @@ -243,7 +243,7 @@ class AccountRecorder(BaseEngine): end_day = dt_now.strftime('%Y%m%d') gw = self.main_engine.get_gateway(gw_name) if gw is None: - continue + self.write_log(f'Account_recorder找不到{gw_name}') if hasattr(gw, 'qryHistory'): self.write_log(u'向{}请求{}数据,{}~{}'.format(gw_name, data_type, begin_day, end_day)) gw.qryHistory(data_type, begin_day, end_day) diff --git a/vnpy/app/cta_crypto/template.py b/vnpy/app/cta_crypto/template.py index 73b8cb66..35ee74cd 100644 --- a/vnpy/app/cta_crypto/template.py +++ b/vnpy/app/cta_crypto/template.py @@ -20,7 +20,7 @@ from .base import StopOrder from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_policy import CtaPolicy - +from vnpy.component.base import MyEncoder class CtaTemplate(ABC): """CTA策略模板""" @@ -1368,7 +1368,7 @@ class CtaFutureTemplate(CtaTemplate): if policy: op = getattr(policy, 'to_json', None) if callable(op): - self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) + self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder))) def save_dist(self, dist_data): """ @@ -2152,7 +2152,7 @@ class CtaSpotTemplate(CtaTemplate): if policy: op = getattr(policy, 'to_json', None) if callable(op): - self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) + self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder))) def save_dist(self, dist_data): """ diff --git a/vnpy/app/cta_stock/template.py b/vnpy/app/cta_stock/template.py index 9c88fe25..c80c5e65 100644 --- a/vnpy/app/cta_stock/template.py +++ b/vnpy/app/cta_stock/template.py @@ -20,6 +20,7 @@ from .base import StopOrder,EngineType from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_policy import CtaPolicy +from vnpy.component.base import MyEncoder class CtaTemplate(ABC): """CTA股票策略模板""" @@ -602,7 +603,7 @@ class CtaStockTemplate(CtaTemplate): """初始化Policy""" self.write_log(u'init_policy(),初始化执行逻辑') self.policy.load() - self.write_log('{}'.format(json.dumps(self.policy.to_json(),indent=2, ensure_ascii=False))) + self.write_log('{}'.format(json.dumps(self.policy.to_json(),indent=2, ensure_ascii=False,cls=MyEncoder))) def init_position(self): """ @@ -1075,9 +1076,12 @@ class CtaStockTemplate(CtaTemplate): continue # 实盘运行时,要加入市场买卖量的判断 + limit_down = None if not force and not self.backtesting: symbol_tick = self.cta_engine.get_tick(vt_symbol) if symbol_tick: + if symbol_tick.limit_down > 0: + limit_down = symbol_tick.limit_down symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) # 根据市场计算,前5档买单数量 if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, @@ -1095,7 +1099,11 @@ class CtaStockTemplate(CtaTemplate): self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume)) # 获取当前价格 - sell_price = cur_price - self.cta_engine.get_price_tick(vt_symbol) + if limit_down is None or cur_price > limit_down: + sell_price = cur_price - self.cta_engine.get_price_tick(vt_symbol) + else: + sell_price = cur_price + # 发出委托卖出 vt_orderids = self.sell( vt_symbol=vt_symbol, @@ -1134,7 +1142,7 @@ class CtaStockTemplate(CtaTemplate): dist_record = dict() dist_record['volume'] = grid.volume dist_record['price'] = self.cta_engine.get_price(grid.vt_symbol) - dist_record['operation'] = 'execute finished' + dist_record['operation'] = 'sell finished' dist_record['signal'] = grid.type self.save_dist(dist_record) @@ -1322,6 +1330,11 @@ class CtaStockTemplate(CtaTemplate): elif order_status == Status.CANCELLED: self.write_log(u'委托单{}已成功撤单,删除{}'.format(vt_orderid, order_info)) canceled_ids.append(vt_orderid) + elif order_status == Status.CANCELLING: + + if over_seconds > self.cancel_seconds * 3: + self.write_log(u'委托单{}正在撤单,超时{},删除{}'.format(vt_orderid,over_seconds, order_info)) + canceled_ids.append(vt_orderid) # 删除撤单的订单 for vt_orderid in canceled_ids: @@ -1381,7 +1394,7 @@ class CtaStockTemplate(CtaTemplate): policy = getattr(self, 'policy') op = getattr(policy, 'to_json', None) if callable(op): - self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) + self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False,cls=MyEncoder))) def save_dist(self, dist_data): """ diff --git a/vnpy/app/cta_strategy_pro/engine.py b/vnpy/app/cta_strategy_pro/engine.py index 35bce4c1..adb06c2b 100644 --- a/vnpy/app/cta_strategy_pro/engine.py +++ b/vnpy/app/cta_strategy_pro/engine.py @@ -76,7 +76,7 @@ from .base import ( STOPORDER_PREFIX, ) from .template import CtaTemplate -from vnpy.component.base import MARKET_DAY_ONLY +from vnpy.component.base import MARKET_DAY_ONLY, MyEncoder from vnpy.component.cta_position import CtaPosition STOP_STATUS_MAP = { diff --git a/vnpy/component/cta_line_bar.py b/vnpy/component/cta_line_bar.py index 9149c980..2208975e 100644 --- a/vnpy/component/cta_line_bar.py +++ b/vnpy/component/cta_line_bar.py @@ -165,6 +165,7 @@ class CtaLineBar(object): self.price_tick = 1 # 商品的最小价格单位 self.round_n = 4 # round() 小数点的截断数量 self.is_7x24 = False # 是否7x24小时运行( 一般为数字货币) + self.is_stock = False # 是否为股票 # 当前的Tick的信息 self.cur_tick = None # 当前 onTick()函数接收的 最新的tick @@ -230,6 +231,8 @@ class CtaLineBar(object): self.minute_interval = None # 把各个周期的bar转换为分钟,在first_tick中,用来修正bar为整点分钟周期 if setting: self.set_params(setting) + if self.is_stock: + self.is_7x24 = True # 修正self.minute_interval if self.interval == Interval.SECOND: @@ -283,6 +286,7 @@ class CtaLineBar(object): self.param_list.append('interval') # bar的类型 self.param_list.append('mode') # tick/bar模式 self.param_list.append('is_7x24') # 是否为7X24小时运行的bar(一般为数字货币) + self.param_list.append('is_stock') # 是否为7X24小时运行的bar(一般为数字货币) self.param_list.append('price_tick') # 最小跳动,用于处理指数等不一致的价格 self.param_list.append('underly_symbol') # 短合约, diff --git a/vnpy/data/common.py b/vnpy/data/common.py index 1a22bdb0..525f86ec 100644 --- a/vnpy/data/common.py +++ b/vnpy/data/common.py @@ -1,20 +1,196 @@ import os import pandas as pd +import numpy as np +from typing import Union, List +from datetime import datetime + +# 所有股票的复权因子 +STOCK_ADJUST_FACTORS = {} + +def get_bardata_folder(data_folder: str) -> str: + """ + 如果data_folder为空白,就返回bar_data的目录 + :param data_folder: + :return: + """ + if len(data_folder) == 0 or not os.path.exists(data_folder): + vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) + data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data')) + return data_folder + +def get_stock_bars(vt_symbol:str, + freq: str = "1d", + start_date: str = "", + fq_type:str ="qfq") -> (List, str): + """ + 获取本地文件的股票bar数据 + :param vt_symbol: + :param freq: + :param start_date: 20180101 或者 2018-01-01 + :param fq_type: qfq:前复权;hfq:后复权; 空白:不复权 + :return: + """ + # 获取未复权的bar dataframe数据 + df, err_msg = get_stock_raw_data(vt_symbol=vt_symbol, freq=freq, start_date=start_date) + bars = [] + if len(err_msg) > 0 or df is None: + return bars, err_msg + + if fq_type != "": + from vnpy.data.stock.adjust_factor import get_all_adjust_factor + STOCK_ADJUST_FACTORS = get_all_adjust_factor() + adj_list = STOCK_ADJUST_FACTORS.get(vt_symbol, []) + + if len(adj_list) > 0: + + for row in adj_list: + row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:30:00'}) + # list -> dataframe, 转换复权日期格式 + adj_data = pd.DataFrame(adj_list) + adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S") + adj_data = adj_data.set_index("dividOperateDate") + # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 + df = stock_to_adj(df, adj_data, adj_type='fore' if fq_type == 'qfw' else 'back') + + from vnpy.trader.object import BarData + from vnpy.trader.constant import Exchange + symbol, exchange = vt_symbol.split('.') + + for dt, bar_data in df.iterrows(): + bar_datetime = dt # - timedelta(seconds=bar_interval_seconds) + + bar = BarData( + gateway_name='backtesting', + symbol=symbol, + exchange=Exchange(exchange), + datetime=bar_datetime + ) + if 'open' in bar_data: + bar.open_price = float(bar_data['open']) + bar.close_price = float(bar_data['close']) + bar.high_price = float(bar_data['high']) + bar.low_price = float(bar_data['low']) + else: + bar.open_price = float(bar_data['open_price']) + bar.close_price = float(bar_data['close_price']) + bar.high_price = float(bar_data['high_price']) + bar.low_price = float(bar_data['low_price']) + + bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0 + bar.date = dt.strftime('%Y-%m-%d') + bar.time = dt.strftime('%H:%M:%S') + str_td = str(bar_data.get('trading_day', '')) + if len(str_td) == 8: + bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8] + else: + bar.trading_day = bar.date + + bars.append(bar) + + return bars, "" + +def get_stock_raw_data(vt_symbol: str, + freq: str = "1d", + start_date: str = "", + bar_data_folder: str = "") -> (Union[pd.DataFrame, None], str): + """ + 获取本地bar_data下的 交易所/股票代码_时间周期.csv原始bar数据(未复权) + :param vt_symbol: 600001.SSE 或 600001 + :param freq: 1m,5m, 15m, 30m, 1h, 1d + :param start_date: 开始日期 + :param bar_data_folder: 强制指定bar_data所在目录 + :return: DataFrame, err_msg + """ + symbol, exchange = vt_symbol.split('.') + # 1分钟 csv文件路径 + csv_file = os.path.abspath(os.path.join( + get_bardata_folder(bar_data_folder), + exchange, + f'{symbol}_{freq}.csv')) + + if not os.path.exists(csv_file): + err_msg = f'{csv_file} 文件不存在,不能读取' + return None, err_msg + try: + # 载入原始csv => dataframe + df = pd.read_csv(csv_file) + + datetime_format = "%Y-%m-%d %H:%M:%S" + # 转换时间,str =》 datetime + df["datetime"] = pd.to_datetime(df["datetime"], format=datetime_format) + # 使用'datetime'字段作为索引 + df.set_index("datetime", inplace=True) + if len(start_date) > 0: + if len(start_date) == 8: + _format = '%Y%m%d' + else: + _format = '%Y-%m-%d' + start_date = datetime.strptime(start_date, _format) + df = df.loc[start_date:] + + return df, "" + + except Exception as ex: + err_msg = f'读取异常:{str(ex)}' + return None, err_msg -def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False): +def stock_to_adj(raw_data: pd.DataFrame, + adj_data: pd.DataFrame, + adj_type: str) -> pd.DataFrame: + """ + 股票数据复权转换 + :param raw_data: 不复权数据 + :param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df) + :param adj_type: 复权类型, fore, 前复权; back,后复权 + :return: + """ + + if adj_type == 'fore': + adj_factor = adj_data["foreAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1 + else: + adj_factor = adj_data["backAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1 + + # 把raw_data的第一个日期,插入复权因子df,使用后填充 + if adj_factor.index[0] != raw_data.index[0]: + adj_factor.loc[raw_data.index[0]] = np.nan + adj_factor.sort_index(inplace=True) + adj_factor = adj_factor.ffill() + + adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引 + adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格 + + # 把复权因子,作为adj字段,补充到raw_data中 + raw_data['adj'] = adj_factor + + # 逐一复权高低开平和成交量 + for col in ['open', 'high', 'low', 'close']: + raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数 + raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数 + + return raw_data + + +def resample_bars_file(vt_symbol: str, + x_mins: List[str] = [], + include_day: bool = False, + bar_data_folder: str = "") -> (list, str): """ 重建x分钟K线(和日线)csv文件 - :param symbol: + :param vt_symbol: 代码.交易所 :param x_mins: [5, 15, 30, 60] :param include_day: 是否也重建日线 + :param vnpy_root: 项目所在根目录 :return: out_files,err_msg """ err_msg = "" out_files = [] + symbol, exchange = vt_symbol.split('.') # 1分钟 csv文件路径 - csv_file = os.path.abspath(os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_1m.csv')) + csv_file = os.path.abspath(os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_1m.csv')) if not os.path.exists(csv_file): err_msg = f'{csv_file} 文件不存在,不能转换' @@ -49,11 +225,14 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False for x_min in x_mins: # 目标文件 target_file = os.path.abspath( - os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_{x_min}m.csv')) + os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_{x_min}m.csv')) # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 - #df_target = df_1m.resample(f'{x_min}min', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any') - df_target = df_1m.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, - how='any') + # df_target = df_1m.resample(f'{x_min}min', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any') + df_target = df_1m.resample( + f'{x_min}min', + closed='left', + label='left').agg(ohlc_rule).dropna(axis=0, + how='any') # dropna(axis=0, how='any') axis参数0:针对行进行操作 1:针对列进行操作 how参数any:只要包含就删除 all:全是为NaN才删除 if len(df_target) > 0: @@ -64,10 +243,13 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False if include_day: # 目标文件 target_file = os.path.abspath( - os.path.join(vnpy_root, 'bar_data', exchange.value, f'{symbol}_1d.csv')) + os.path.join(get_bardata_folder(bar_data_folder), exchange, f'{symbol}_1d.csv')) # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 # df_target = df_1m.resample(f'D', how=ohlc_rule, closed='left', label='left').dropna(axis=0, how='any') - df_target = df_1m.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any') + df_target = df_1m.resample( + f'D', + closed='left', + label='left').agg(ohlc_rule).dropna(axis=0, how='any') # dropna(axis=0, how='any') axis参数0:针对行进行操作 1:针对列进行操作 how参数any:只要包含就删除 all:全是为NaN才删除 if len(df_target) > 0: @@ -75,4 +257,4 @@ def resample_bars_file(vnpy_root, symbol, exchange, x_mins=[], include_day=False print(f'生成[日线] => {target_file}') out_files.append(target_file) - return out_files,err_msg + return out_files, err_msg diff --git a/vnpy/data/stock/adjust_factor.py b/vnpy/data/stock/adjust_factor.py index 64c0e1a7..c3d3fa6c 100644 --- a/vnpy/data/stock/adjust_factor.py +++ b/vnpy/data/stock/adjust_factor.py @@ -130,4 +130,21 @@ def download_adjust_factor(): return factor_dict if __name__ == '__main__': - download_adjust_factor() + + # 下载所有复权数据 + # download_adjust_factor() + + # 下载某个股票的复权数据 + # f = get_adjust_factor(vt_symbol='000651.SZSE',stock_name='格力电器',need_login=True) + # + # for d in f: + # print(d) + + # 读取缓存文件中某只股票的复权数据 + factors = get_all_adjust_factor() + f = factors.get('000651.SZSE',None) + if f is None: + print('获取不到数据') + else: + for d in f: + print(d) diff --git a/vnpy/data/tdx/tdx_common.py b/vnpy/data/tdx/tdx_common.py index 496b2abd..6c2df26d 100644 --- a/vnpy/data/tdx/tdx_common.py +++ b/vnpy/data/tdx/tdx_common.py @@ -34,7 +34,7 @@ TDX_PROXY_CONFIG = 'tdx_proxy_config.json' def get_tdx_market_code(code): # 获取通达信股票的market code code = str(code) - if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]: + if code[0] in ['5', '6', '9'] or code[:3] in ["880","009", "126", "110", "201", "202", "203", "204"]: # 上海证券交易所 return 1 # 深圳证券交易所 @@ -101,11 +101,14 @@ def get_cache_config(config_file_name): config = {} if not os.path.exists(config_file_name): return config - with bz2.BZ2File(config_file_name, 'rb') as f: - config = pickle.load(f) + try: + with bz2.BZ2File(config_file_name, 'rb') as f: + config = pickle.load(f) + return config + except Exception as ex: + print(f'读取缓存本地文件:{config_file_name}异常{str(ex)}') return config - def save_cache_config(data: dict, config_file_name): """保存本地缓存的配置地址信息""" config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), config_file_name)) @@ -126,9 +129,10 @@ def save_cache_json(data_dict: dict, json_file_name: str): save_json(filename=config_file_name, data=data_dict) -def get_stock_type(code): +def get_stock_type(code,market_id = None ): """获取股票得分类""" - market_id = get_tdx_market_code(code) + if market_id is None: + market_id = get_tdx_market_code(code) if market_id == 0: return get_stock_type_sz(code) diff --git a/vnpy/data/tdx/tdx_stock_data.py b/vnpy/data/tdx/tdx_stock_data.py index 5c946981..f293544c 100644 --- a/vnpy/data/tdx/tdx_stock_data.py +++ b/vnpy/data/tdx/tdx_stock_data.py @@ -314,13 +314,18 @@ class TdxStockData(object): self.write_log('{}开始下载tdx股票: {},代码:{} {}数据, {} to {}.' .format(datetime.now(), name, tdx_code, tdx_period, qry_start_date, qry_end_date)) + stock_type = get_stock_type(tdx_code,market_id) + if stock_type == 'index_cn': + get_bar_func = self.api.get_index_bars + else: + get_bar_func = self.api.get_security_bars try: _start_date = qry_end_date _bars = [] _pos = 0 while _start_date > qry_start_date: - _res = self.api.get_security_bars( + _res = get_bar_func( category=PERIOD_MAPPING[period], market=market_id, code=tdx_code, @@ -452,8 +457,13 @@ class TdxStockData(object): .format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) return False, ret_bars tdx_period = PERIOD_MAPPING.get(period) + stock_type = get_stock_type(tdx_code) + if stock_type == 'index_cn': + get_bar_func = self.api.get_index_bars + else: + get_bar_func = self.api.get_security_bars try: - datas = self.api.get_security_bars( + datas = get_bar_func( category=PERIOD_MAPPING[period], market=market_id, code=tdx_code, @@ -490,6 +500,8 @@ class TdxStockData(object): self.write_error(f'获取{symbol}数据失败:{str(ex)}') return False, ret_bars + # ---------------------------------------------------------------------- + def save_cache(self, cache_folder: str, cache_symbol: str, diff --git a/vnpy/data/tdx/test_tdx_stock.py b/vnpy/data/tdx/test_tdx_stock.py index 49c372c5..089e2cc6 100644 --- a/vnpy/data/tdx/test_tdx_stock.py +++ b/vnpy/data/tdx/test_tdx_stock.py @@ -19,26 +19,36 @@ t1 = FakeStrategy() t2 = FakeStrategy() # 创建API对象(使用本地socket5代理) -api_01 = TdxStockData(strategy=t1, proxy_ip='localhost', proxy_port=1080) +#api_01 = TdxStockData(strategy=t1, proxy_ip='localhost', proxy_port=1080) # 不使用代理 -#api_01 = TdxStockData(strategy=t1) - -# 获取市场下股票 -for market_id in range(2): - print('get market_id:{}'.format(market_id)) - security_list = api_01.get_security_list(market_id) - if len(security_list) == 0: - continue - for security in security_list: - if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''): - str_security = json.dumps(security, indent=1, ensure_ascii=False) - print('market_id:{},{}'.format(market_id, str_security)) +api_01 = TdxStockData(strategy=t1) +# +# # 获取市场下股票 +# for market_id in range(2): +# print('get market_id:{}'.format(market_id)) +# security_list = api_01.get_security_list(market_id) +# if len(security_list) == 0: +# continue +# for security in security_list: +# if security['code'].startswith('88'): +# str_security = json.dumps(security, indent=1, ensure_ascii=False) +# print(str_security) +# if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''): +# str_security = json.dumps(security, indent=1, ensure_ascii=False) +# # print('market_id:{},{}'.format(market_id, str_security)) # str_markets = json.dumps(security_list, indent=1, ensure_ascii=False) # print(u'{}'.format(str_markets)) # 获取历史分钟线 -# api_01.get_bars('002024', period='1hour', callback=t1.display_bar) +ret,result = api_01.get_bars('880351.SSE', period='1hour', callback=t1.display_bar) +if ret: + for bar in result: + print(bar) +# ret,result = api_01.get_last_bars(symbol='002024',return_bar=True) +# if ret: +# print(result) + # api.get_bars(symbol, period='5min', callback=display_bar) # api.get_bars(symbol, period='1day', callback=display_bar) @@ -50,7 +60,7 @@ for market_id in range(2): # for r in result[0:10] + result[-10:]: # print(r) -# 获取历史分时数据 -ret, result = api_01.get_history_transaction_data('110031', '20200504') -for r in result[0:10] + result[-10:]: - print(r) +# # 获取历史分时数据 +# ret, result = api_01.get_history_transaction_data('110031', '20200504') +# for r in result[0:10] + result[-10:]: +# print(r) diff --git a/vnpy/gateway/ths/ths_gateway.py b/vnpy/gateway/ths/ths_gateway.py index 6d20ef9b..274c44ce 100644 --- a/vnpy/gateway/ths/ths_gateway.py +++ b/vnpy/gateway/ths/ths_gateway.py @@ -940,11 +940,21 @@ class ThsTdApi(object): return if '总资产' not in data: return + + ## 为了兼容东财的webapi,这里frozen做个特殊处理 + # if "冻结金额" in data: + # # 同花顺直接给了冻结金额 + # frozen = float(data["冻结金额"]) + # else: + # # 东财没有冻结金额这个项目,要计算 + # frozen = float(data["总资产"]) - float(data["资金余额"]) + frozen = 0 + account = AccountData( gateway_name=self.gateway_name, accountid=self.userid, balance=float(data["总资产"]), - frozen=float(data["总资产"]) - float(data["资金余额"]), + frozen=frozen, currency="人民币", trading_day=self.trading_day ) diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 8884fe9e..61b2de52 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -578,16 +578,20 @@ class IndexGenerator: self.exchange = setting.get('exchange', None) self.price_tick = setting.get('price_tick') self.symbols = setting.get('symbols', {}) + self.pre_oi_total = 1 + # 订阅行情 self.subscribe() self.n = len(self.symbols) + def subscribe(self): """订阅行情""" dt_now = datetime.now() for symbol in list(self.symbols.keys()): pre_open_interest = self.symbols.get(symbol,0) + self.pre_oi_total += pre_open_interest # 全路径合约 => 标准合约 ,如 ZC2109 => ZC109, RB2110 => rb2110 vn_symbol = get_real_symbol_by_exchange(symbol, Exchange(self.exchange)) # 先移除 @@ -596,6 +600,9 @@ class IndexGenerator: self.gateway.write_log(f'移除早于当月的合约{symbol}') continue + if pre_open_interest < 100: + self.gateway.write_log(f'移除持仓量:{pre_open_interest}低于100的合约{symbol}') + continue # 重新登记合约 self.symbols[vn_symbol] = pre_open_interest @@ -625,12 +632,6 @@ class IndexGenerator: bid_price_1 = 0 mi_tick = None - # 已经积累的行情tick数量,不足总数减1,不处理 - - if len(self.ticks) < min(self.n * 0.8, 3): - self.gateway.write_log(f'{self.underlying_symbol}合约数据{len(self.ticks)}不足{self.n} 0.8,暂不合成指数') - return - # 计算所有合约的累加持仓量、资金、成交量、找出最大持仓量的主力合约 for t in self.ticks.values(): all_interest += t.open_interest @@ -641,6 +642,10 @@ class IndexGenerator: if mi_tick is None or mi_tick.open_interest < t.open_interest: mi_tick = t + if not (len(self.ticks) > min(self.n * 0.7, 3) or all_interest > self.pre_oi_total * 0.5): + self.gateway.write_log(f'{self.underlying_symbol}合约数据{len(self.ticks)}不足{self.n} 0.7,或者累计持仓数不够昨持仓0.5,暂不合成指数') + return + # 总量 > 0 if all_interest > 0 and all_amount > 0: last_price = round(float(all_amount / all_interest), 4)