[增强功能] 普通更新
This commit is contained in:
parent
83771b8e56
commit
c76202586c
@ -87,6 +87,7 @@ class BackTestingEngine(object):
|
||||
# 引擎类型为回测
|
||||
self.engine_type = EngineType.BACKTESTING
|
||||
self.contract_type = 'future' # future, stock, digital
|
||||
self.using_99_contract = True
|
||||
|
||||
# 回测策略相关
|
||||
self.classes = {} # 策略类,class_name: stategy_class
|
||||
@ -450,6 +451,10 @@ class BackTestingEngine(object):
|
||||
|
||||
self.debug = test_setting.get('debug', False)
|
||||
|
||||
if 'using_99_contract' in test_setting:
|
||||
self.using_99_contract = test_setting.get('using_99_contract')
|
||||
self.write_log(f'是否使用指数合约:{self.using_99_contract}')
|
||||
|
||||
# 更新数据目录
|
||||
if 'data_path' in test_setting:
|
||||
self.data_path = test_setting.get('data_path')
|
||||
@ -711,15 +716,21 @@ class BackTestingEngine(object):
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
elif self.contract_type == 'future':
|
||||
symbol = vt_symbol
|
||||
underly_symbol = get_underlying_symbol(symbol).upper()
|
||||
if self.using_99_contract:
|
||||
underly_symbol = get_underlying_symbol(symbol).upper() # WJ: 当需要回测A1701.DCE时,不能替换成99合约。
|
||||
exchange = self.get_exchange(f'{underly_symbol}99')
|
||||
else:
|
||||
exchange = self.get_exchange(symbol)
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
strategy_setting.update({'vt_symbol': vt_symbol})
|
||||
else:
|
||||
symbol = vt_symbol
|
||||
exchange = Exchange.LOCAL
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
strategy_setting.update({'vt_symbol': vt_symbol})
|
||||
|
||||
# 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约
|
||||
# 在期货组合回测时,如果直接使用运行得配置,需要把一般配置的主力合约,更换为指数合约
|
||||
if self.using_99_contract:
|
||||
if '99' not in symbol and exchange != Exchange.SPD and self.contract_type == 'future':
|
||||
underly_symbol = get_underlying_symbol(symbol).upper()
|
||||
self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value))
|
||||
@ -2186,6 +2197,9 @@ class BackTestingEngine(object):
|
||||
'test_setting': self.test_setting, # 回测参数
|
||||
'strategy_setting': self.strategy_setting, # 策略参数
|
||||
}
|
||||
# 去除包含"."的域
|
||||
if 'symbol_datas' in d['test_setting'].keys():
|
||||
d['test_setting'].pop('symbol_datas')
|
||||
|
||||
# 保存入数据库
|
||||
self.mongo_api.db_insert(
|
||||
|
@ -226,8 +226,8 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
if exchange == Exchange.SPD:
|
||||
try:
|
||||
active_symbol, active_rate, passive_symbol, passive_rate, spd_type = symbol.split('-')
|
||||
active_vt_symbol = '.'.join([active_symbol, self.get_exchange(symbol=active_symbol)])
|
||||
passive_vt_symbol = '.'.join([passive_symbol, self.get_exchange(symbol=passive_symbol)])
|
||||
active_vt_symbol = '.'.join([active_symbol, self.get_exchange(symbol=active_symbol).value])
|
||||
passive_vt_symbol = '.'.join([passive_symbol, self.get_exchange(symbol=passive_symbol).value])
|
||||
self.load_bar_csv_to_df(active_vt_symbol, self.bar_csv_file.get(active_symbol))
|
||||
self.load_bar_csv_to_df(passive_vt_symbol, self.bar_csv_file.get(passive_symbol))
|
||||
except Exception as ex:
|
||||
@ -282,6 +282,8 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
str_td = str(bar_data.get('trading_day', ''))
|
||||
if len(str_td) == 8:
|
||||
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
|
||||
elif len(str_td) == 10:
|
||||
bar.trading_day = str_td
|
||||
else:
|
||||
bar.trading_day = get_trading_date(bar_datetime)
|
||||
|
||||
@ -473,6 +475,9 @@ def single_test(test_setting: dict, strategy_setting: dict):
|
||||
# 回测结果,保存
|
||||
engine.show_backtesting_result()
|
||||
|
||||
# 保存策略得内部数据
|
||||
engine.save_strategy_data()
|
||||
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
traceback.print_exc()
|
||||
|
@ -403,10 +403,14 @@ class CtaSpreadTemplate(CtaTemplate):
|
||||
|
||||
# 找到委托单记录
|
||||
order_info = None
|
||||
# 优先从活动订单获取
|
||||
if trade.vt_orderid in self.active_orders.keys():
|
||||
order_info = self.active_orders.get(trade.vt_orderid)
|
||||
# 如果找不到,可能被移动到历史订单中,从历史订单获取
|
||||
if not order_info:
|
||||
if trade.vt_orderid in self.history_orders.keys():
|
||||
order_info = self.history_orders.get(trade.vt_orderid)
|
||||
# 找到委托记录
|
||||
if order_info is not None:
|
||||
# 委托单记录 =》 找到 Grid
|
||||
grid = order_info.get('grid')
|
||||
|
@ -18,7 +18,7 @@ from vnpy.trader.constant import (
|
||||
OrderType,
|
||||
Interval,
|
||||
)
|
||||
from vnpy.trader.object import TickData
|
||||
from vnpy.trader.object import TickData, BarData
|
||||
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
|
||||
import pandas as pd
|
||||
import csv
|
||||
@ -128,7 +128,7 @@ class TqFutureData():
|
||||
def __init__(self, strategy=None):
|
||||
self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例
|
||||
|
||||
self.api = TqApi(TqSim())
|
||||
self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile")
|
||||
|
||||
def get_tick_serial(self, vt_symbol: str):
|
||||
# 获取最新的8964个数据 tick的话就相当于只有50分钟左右
|
||||
@ -191,6 +191,51 @@ class TqFutureData():
|
||||
|
||||
return []
|
||||
|
||||
def get_bars(self, vt_symbol: str, start_date: datetime=None, end_date: datetime = None):
|
||||
"""
|
||||
获取历史bar(受限于最大长度8964根bar)
|
||||
:param vt_symbol:
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:return:
|
||||
"""
|
||||
|
||||
self.write_log(f"从天勤请求合约:{vt_symbol}开始时间:{start_date}的历史1分钟bar数据")
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
|
||||
# 获取一分钟数据
|
||||
df = self.api.get_kline_serial(symbol=f'{exchange.value}.{symbol}', duration_seconds=60, data_length=8964)
|
||||
bars = []
|
||||
if df is None:
|
||||
self.write_error(f'返回空白dataframe')
|
||||
return []
|
||||
|
||||
for index, row in df.iterrows():
|
||||
bar_datetime = datetime.strptime(self._nano_to_str(row['datetime']), "%Y-%m-%d %H:%M:%S.%f")
|
||||
if start_date:
|
||||
if bar_datetime < start_date:
|
||||
continue
|
||||
if end_date:
|
||||
if bar_datetime > end_date:
|
||||
continue
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=bar_datetime,
|
||||
open_price=row['open'],
|
||||
close_price=row['close'],
|
||||
high_price=row['high'],
|
||||
low_price=row['low'],
|
||||
volume=row['volume'],
|
||||
open_interest=row['close_oi'],
|
||||
trading_day=get_trading_date(bar_datetime),
|
||||
gateway_name='tq'
|
||||
)
|
||||
bars.append(bar)
|
||||
|
||||
return bars
|
||||
|
||||
|
||||
def get_ticks(self, vt_symbol: str, start_date: datetime, end_date: datetime = None):
|
||||
"""获取历史tick"""
|
||||
|
||||
@ -291,6 +336,7 @@ class TqFutureData():
|
||||
else:
|
||||
self.strategy.write_error(msg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# tqsdk = Query_tqsdk_data(strategy=self) # 在策略中使用
|
||||
tqsdk = TqFutureData()
|
||||
@ -298,11 +344,15 @@ if __name__ == '__main__':
|
||||
#tick_df = tqsdk.query_tick_history_data(vt_symbol="ni2009.SHFE", start_date=pd.to_datetime("2020-07-22"))
|
||||
#print(tick_df)
|
||||
|
||||
ticks = tqsdk.get_runtime_ticks("ni2009.SHFE")
|
||||
#ticks = tqsdk.get_runtime_ticks("ni2009.SHFE")
|
||||
|
||||
print(ticks[0])
|
||||
#print(ticks[0])
|
||||
|
||||
#print(ticks[-1])
|
||||
bars = tqsdk.get_bars(vt_symbol='ni2011.SHFE')
|
||||
print(bars[0])
|
||||
print(bars[-1])
|
||||
|
||||
print(ticks[-1])
|
||||
|
||||
|
||||
|
||||
|
@ -926,6 +926,12 @@ class CtpTdApi(TdApi):
|
||||
position.cur_price = self.gateway.prices.get(position.vt_symbol, None)
|
||||
if position.cur_price is None:
|
||||
position.cur_price = position.price
|
||||
# 交易所有时候会给一些奇怪得套利合约,并且主动腿和被动腿是相同得,排除掉这些垃圾合约
|
||||
if position.symbol.startswith('SP') and '&' in position.symbol:
|
||||
act_symbol, pas_symbol = position.symbol.split(' ')[-1].split('&')
|
||||
if act_symbol != pas_symbol:
|
||||
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
|
||||
else:
|
||||
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
|
||||
|
||||
if last:
|
||||
|
@ -214,13 +214,6 @@ class XtpGateway(BaseGateway):
|
||||
self.query_functions = [self.query_account, self.query_position]
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
|
||||
def write_error(self, msg: str, error: dict) -> None:
|
||||
""""""
|
||||
error_id = error["error_id"]
|
||||
error_msg = error["error_msg"]
|
||||
msg = f"{msg},代码:{error_id},信息:{error_msg}"
|
||||
self.write_log(msg)
|
||||
|
||||
|
||||
class XtpMdApi(MdApi):
|
||||
|
||||
|
@ -163,16 +163,21 @@ class PositionHolding:
|
||||
self.update_order(order)
|
||||
|
||||
def update_trade(self, trade: TradeData) -> None:
|
||||
""""""
|
||||
"""更新交易"""
|
||||
|
||||
if trade.direction == Direction.LONG:
|
||||
# 多,开仓 =》 增加今仓
|
||||
if trade.offset == Offset.OPEN:
|
||||
self.long_td += trade.volume
|
||||
# 多,平今 =》减少今仓
|
||||
elif trade.offset == Offset.CLOSETODAY:
|
||||
self.short_td -= trade.volume
|
||||
# 多,平昨 =》减少昨仓
|
||||
elif trade.offset == Offset.CLOSEYESTERDAY:
|
||||
self.short_yd -= trade.volume
|
||||
# 多,平仓 =》 减少
|
||||
elif trade.offset == Offset.CLOSE:
|
||||
if trade.exchange in [Exchange.SHFE, Exchange.INE]:
|
||||
if trade.exchange in [Exchange.SHFE, Exchange.INE] and self.short_yd >=trade.volume:
|
||||
self.short_yd -= trade.volume
|
||||
else:
|
||||
self.short_td -= trade.volume
|
||||
@ -191,7 +196,7 @@ class PositionHolding:
|
||||
elif trade.offset == Offset.CLOSEYESTERDAY:
|
||||
self.long_yd -= trade.volume
|
||||
elif trade.offset == Offset.CLOSE:
|
||||
if trade.exchange in [Exchange.SHFE, Exchange.INE]:
|
||||
if trade.exchange in [Exchange.SHFE, Exchange.INE] and self.long_yd >=trade.volume:
|
||||
self.long_yd -= trade.volume
|
||||
else:
|
||||
self.long_td -= trade.volume
|
||||
|
Loading…
Reference in New Issue
Block a user