[增强功能] 普通更新

This commit is contained in:
msincenselee 2020-09-07 16:28:37 +08:00
parent 83771b8e56
commit c76202586c
7 changed files with 106 additions and 29 deletions

View File

@ -87,6 +87,7 @@ class BackTestingEngine(object):
# 引擎类型为回测
self.engine_type = EngineType.BACKTESTING
self.contract_type = 'future' # future, stock, digital
self.using_99_contract = True
# 回测策略相关
self.classes = {} # 策略类class_name: stategy_class
@ -450,6 +451,10 @@ class BackTestingEngine(object):
self.debug = test_setting.get('debug', False)
if 'using_99_contract' in test_setting:
self.using_99_contract = test_setting.get('using_99_contract')
self.write_log(f'是否使用指数合约:{self.using_99_contract}')
# 更新数据目录
if 'data_path' in test_setting:
self.data_path = test_setting.get('data_path')
@ -711,15 +716,21 @@ class BackTestingEngine(object):
symbol, exchange = extract_vt_symbol(vt_symbol)
elif self.contract_type == 'future':
symbol = vt_symbol
underly_symbol = get_underlying_symbol(symbol).upper()
if self.using_99_contract:
underly_symbol = get_underlying_symbol(symbol).upper() # WJ: 当需要回测A1701.DCE时不能替换成99合约。
exchange = self.get_exchange(f'{underly_symbol}99')
else:
exchange = self.get_exchange(symbol)
vt_symbol = '.'.join([symbol, exchange.value])
strategy_setting.update({'vt_symbol': vt_symbol})
else:
symbol = vt_symbol
exchange = Exchange.LOCAL
vt_symbol = '.'.join([symbol, exchange.value])
strategy_setting.update({'vt_symbol': vt_symbol})
# 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约
# 在期货组合回测时,如果直接使用运行得配置,需要把一般配置的主力合约,更换为指数合约
if self.using_99_contract:
if '99' not in symbol and exchange != Exchange.SPD and self.contract_type == 'future':
underly_symbol = get_underlying_symbol(symbol).upper()
self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value))
@ -2186,6 +2197,9 @@ class BackTestingEngine(object):
'test_setting': self.test_setting, # 回测参数
'strategy_setting': self.strategy_setting, # 策略参数
}
# 去除包含"."的域
if 'symbol_datas' in d['test_setting'].keys():
d['test_setting'].pop('symbol_datas')
# 保存入数据库
self.mongo_api.db_insert(

View File

@ -226,8 +226,8 @@ class PortfolioTestingEngine(BackTestingEngine):
if exchange == Exchange.SPD:
try:
active_symbol, active_rate, passive_symbol, passive_rate, spd_type = symbol.split('-')
active_vt_symbol = '.'.join([active_symbol, self.get_exchange(symbol=active_symbol)])
passive_vt_symbol = '.'.join([passive_symbol, self.get_exchange(symbol=passive_symbol)])
active_vt_symbol = '.'.join([active_symbol, self.get_exchange(symbol=active_symbol).value])
passive_vt_symbol = '.'.join([passive_symbol, self.get_exchange(symbol=passive_symbol).value])
self.load_bar_csv_to_df(active_vt_symbol, self.bar_csv_file.get(active_symbol))
self.load_bar_csv_to_df(passive_vt_symbol, self.bar_csv_file.get(passive_symbol))
except Exception as ex:
@ -282,6 +282,8 @@ class PortfolioTestingEngine(BackTestingEngine):
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
elif len(str_td) == 10:
bar.trading_day = str_td
else:
bar.trading_day = get_trading_date(bar_datetime)
@ -473,6 +475,9 @@ def single_test(test_setting: dict, strategy_setting: dict):
# 回测结果,保存
engine.show_backtesting_result()
# 保存策略得内部数据
engine.save_strategy_data()
except Exception as ex:
print('组合回测异常{}'.format(str(ex)))
traceback.print_exc()

View File

@ -403,10 +403,14 @@ class CtaSpreadTemplate(CtaTemplate):
# 找到委托单记录
order_info = None
# 优先从活动订单获取
if trade.vt_orderid in self.active_orders.keys():
order_info = self.active_orders.get(trade.vt_orderid)
# 如果找不到,可能被移动到历史订单中,从历史订单获取
if not order_info:
if trade.vt_orderid in self.history_orders.keys():
order_info = self.history_orders.get(trade.vt_orderid)
# 找到委托记录
if order_info is not None:
# 委托单记录 =》 找到 Grid
grid = order_info.get('grid')

View File

@ -18,7 +18,7 @@ from vnpy.trader.constant import (
OrderType,
Interval,
)
from vnpy.trader.object import TickData
from vnpy.trader.object import TickData, BarData
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
import pandas as pd
import csv
@ -128,7 +128,7 @@ class TqFutureData():
def __init__(self, strategy=None):
self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例
self.api = TqApi(TqSim())
self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile")
def get_tick_serial(self, vt_symbol: str):
# 获取最新的8964个数据 tick的话就相当于只有50分钟左右
@ -191,6 +191,51 @@ class TqFutureData():
return []
def get_bars(self, vt_symbol: str, start_date: datetime=None, end_date: datetime = None):
"""
获取历史bar受限于最大长度8964根bar
:param vt_symbol:
:param start_date:
:param end_date:
:return:
"""
self.write_log(f"从天勤请求合约:{vt_symbol}开始时间:{start_date}的历史1分钟bar数据")
symbol, exchange = extract_vt_symbol(vt_symbol)
# 获取一分钟数据
df = self.api.get_kline_serial(symbol=f'{exchange.value}.{symbol}', duration_seconds=60, data_length=8964)
bars = []
if df is None:
self.write_error(f'返回空白dataframe')
return []
for index, row in df.iterrows():
bar_datetime = datetime.strptime(self._nano_to_str(row['datetime']), "%Y-%m-%d %H:%M:%S.%f")
if start_date:
if bar_datetime < start_date:
continue
if end_date:
if bar_datetime > end_date:
continue
bar = BarData(
symbol=symbol,
exchange=exchange,
datetime=bar_datetime,
open_price=row['open'],
close_price=row['close'],
high_price=row['high'],
low_price=row['low'],
volume=row['volume'],
open_interest=row['close_oi'],
trading_day=get_trading_date(bar_datetime),
gateway_name='tq'
)
bars.append(bar)
return bars
def get_ticks(self, vt_symbol: str, start_date: datetime, end_date: datetime = None):
"""获取历史tick"""
@ -291,6 +336,7 @@ class TqFutureData():
else:
self.strategy.write_error(msg)
if __name__ == '__main__':
# tqsdk = Query_tqsdk_data(strategy=self) # 在策略中使用
tqsdk = TqFutureData()
@ -298,11 +344,15 @@ if __name__ == '__main__':
#tick_df = tqsdk.query_tick_history_data(vt_symbol="ni2009.SHFE", start_date=pd.to_datetime("2020-07-22"))
#print(tick_df)
ticks = tqsdk.get_runtime_ticks("ni2009.SHFE")
#ticks = tqsdk.get_runtime_ticks("ni2009.SHFE")
print(ticks[0])
#print(ticks[0])
#print(ticks[-1])
bars = tqsdk.get_bars(vt_symbol='ni2011.SHFE')
print(bars[0])
print(bars[-1])
print(ticks[-1])

View File

@ -926,6 +926,12 @@ class CtpTdApi(TdApi):
position.cur_price = self.gateway.prices.get(position.vt_symbol, None)
if position.cur_price is None:
position.cur_price = position.price
# 交易所有时候会给一些奇怪得套利合约,并且主动腿和被动腿是相同得,排除掉这些垃圾合约
if position.symbol.startswith('SP') and '&' in position.symbol:
act_symbol, pas_symbol = position.symbol.split(' ')[-1].split('&')
if act_symbol != pas_symbol:
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
else:
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
if last:

View File

@ -214,13 +214,6 @@ class XtpGateway(BaseGateway):
self.query_functions = [self.query_account, self.query_position]
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
def write_error(self, msg: str, error: dict) -> None:
""""""
error_id = error["error_id"]
error_msg = error["error_msg"]
msg = f"{msg},代码:{error_id},信息:{error_msg}"
self.write_log(msg)
class XtpMdApi(MdApi):

View File

@ -163,16 +163,21 @@ class PositionHolding:
self.update_order(order)
def update_trade(self, trade: TradeData) -> None:
""""""
"""更新交易"""
if trade.direction == Direction.LONG:
# 多,开仓 =》 增加今仓
if trade.offset == Offset.OPEN:
self.long_td += trade.volume
# 多,平今 =》减少今仓
elif trade.offset == Offset.CLOSETODAY:
self.short_td -= trade.volume
# 多,平昨 =》减少昨仓
elif trade.offset == Offset.CLOSEYESTERDAY:
self.short_yd -= trade.volume
# 多,平仓 =》 减少
elif trade.offset == Offset.CLOSE:
if trade.exchange in [Exchange.SHFE, Exchange.INE]:
if trade.exchange in [Exchange.SHFE, Exchange.INE] and self.short_yd >=trade.volume:
self.short_yd -= trade.volume
else:
self.short_td -= trade.volume
@ -191,7 +196,7 @@ class PositionHolding:
elif trade.offset == Offset.CLOSEYESTERDAY:
self.long_yd -= trade.volume
elif trade.offset == Offset.CLOSE:
if trade.exchange in [Exchange.SHFE, Exchange.INE]:
if trade.exchange in [Exchange.SHFE, Exchange.INE] and self.long_yd >=trade.volume:
self.long_yd -= trade.volume
else:
self.long_td -= trade.volume