[增强] 支持策略on_timer, 获取账号资金/风控,定时更新策略持仓

This commit is contained in:
msincenselee 2020-01-19 17:23:31 +08:00
parent e0d7dff896
commit 5208c85d77
5 changed files with 469 additions and 39 deletions

View File

@ -1,7 +1,7 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from vnpy.trader.constant import Direction
from vnpy.trader.constant import Direction,Offset
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
from vnpy.trader.utility import BarGenerator, ArrayManager

View File

@ -26,11 +26,13 @@ from vnpy.trader.object import (
ContractData
)
from vnpy.trader.event import (
EVENT_TIMER,
EVENT_TICK,
EVENT_BAR,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION
EVENT_POSITION,
EVENT_STRATEGY_POS,
)
from vnpy.trader.constant import (
Direction,
@ -40,8 +42,13 @@ from vnpy.trader.constant import (
Offset,
Status
)
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to, get_folder_path, \
get_underlying_symbol
from vnpy.trader.utility import (
load_json, save_json,
extract_vt_symbol,
round_to, get_folder_path,
get_underlying_symbol,
append_data)
from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.converter import OffsetConverter
@ -58,6 +65,7 @@ from .base import (
)
from .template import CtaTemplate
from .cta_position import CtaPosition
STOP_STATUS_MAP = {
Status.SUBMITTING: StopOrderStatus.WAITING,
@ -80,16 +88,21 @@ class CtaEngine(BaseEngine):
6支持指定gateway的交易主引擎可接入多个gateway
"""
engine_type = EngineType.LIVE # live trading engine
engine_type = EngineType.LIVE # live trading engine
setting_filename = "cta_strategy_setting.json"
data_filename = "cta_strategy_data.json"
# 策略配置文件
setting_filename = "cta_strategy_pro_setting.json"
# 引擎配置文件
engine_filename = "cta_strategy_pro_config.json"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super(CtaEngine, self).__init__(
main_engine, event_engine, APP_NAME)
self.engine_config = {}
self.strategy_setting = {} # strategy_name: dict
self.strategy_data = {} # strategy_name: dict
@ -97,6 +110,9 @@ class CtaEngine(BaseEngine):
self.class_module_map = {} # class_name: mudule_name
self.strategies = {} # strategy_name: strategy
# Strategy pos dict,key:strategy instance name, value: pos dict
self.strategy_pos_dict = {}
self.strategy_loggers = {} # strategy_name: logger
# 未能订阅的symbols,支持策略启动时并未接入gateway
@ -120,6 +136,8 @@ class CtaEngine(BaseEngine):
self.offset_converter = OffsetConverter(self.main_engine)
self.last_minute = None
def init_engine(self):
"""
"""
@ -128,43 +146,44 @@ class CtaEngine(BaseEngine):
self.register_event()
self.write_log("CTA策略引擎初始化成功")
def append_data(self, file_name: str, dict_data: dict, field_names: list = []):
"""
添加数据到csv文件中
:param file_name: csv的文件全路径
:param dict_data: OrderedDict
:return:
"""
dict_fieldnames = sorted(list(dict_data.keys())) if len(field_names) == 0 else field_names
try:
if not os.path.exists(file_name):
self.write_log(u'create csv file:{}'.format(file_name))
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel')
self.write_log(u'write csv header:{}'.format(dict_fieldnames))
writer.writeheader()
writer.writerow(dict_data)
else:
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel',
extrasaction='ignore')
writer.writerow(dict_data)
except Exception as ex:
self.write_error(u'append_data exception:{}'.format(str(ex)))
def close(self):
"""停止所属有的策略"""
self.stop_all_strategies()
def register_event(self):
"""注册事件"""
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_BAR, self.process_bar_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
self.event_engine.register(EVENT_POSITION, self.process_position_event)
def register_funcs(self):
"""
register the funcs to main_engine
:return:
"""
self.main_engine.get_strategy_status = self.get_strategy_status
def process_timer_event(self, event: Event):
""" 处理定时器事件"""
# 触发每个策略的定时接口
for strategy in list(self.strategies.values()):
strategy.on_timer()
dt = datetime.now()
if self.last_minute != dt.minute:
self.last_minute = dt.minute
# 主动获取所有策略得持仓信息
all_strategy_pos = self.get_all_strategy_pos()
# 推送到事件
self.put_all_strategy_pos_event(all_strategy_pos)
def process_tick_event(self, event: Event):
"""处理tick到达事件"""
tick = event.data
@ -269,7 +288,7 @@ class CtaEngine(BaseEngine):
if strategy_name is not None:
trade_file = os.path.abspath(
os.path.join(get_folder_path('data'), '{}_trade.csv'.format(strategy_name)))
self.append_data(file_name=trade_file, dict_data=trade_dict)
append_data(file_name=trade_file, dict_data=trade_dict)
except Exception as ex:
self.write_error(u'写入交易记录csv出错{},{}'.format(str(ex), traceback.format_exc()))
@ -737,9 +756,22 @@ class CtaEngine(BaseEngine):
return None
def get_account(self, vt_accountid: str):
def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金"""
return self.main_engine.get_account(vt_accountid)
# 如果启动风控,则使用风控中的最大仓位
if self.main_engine.rm_engine:
return self.main_engine.rm_engine.get_account(vt_accountid)
if len(vt_accountid) > 0:
account = self.main_engine.get_account(vt_accountid)
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
else:
accounts = self.main_engine.get_all_accounts()
if len(accounts) > 0:
account = accounts[0]
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
else:
return 0, 0, 0, 0
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓,需要指定方向"""
@ -771,7 +803,7 @@ class CtaEngine(BaseEngine):
msg = f"触发异常已停止\n{traceback.format_exc()}"
self.write_log(msg=msg,
strategy_name=strategy.name,
strategy_name=strategy.strategy_name,
level=logging.CRITICAL)
def add_strategy(
@ -1062,6 +1094,160 @@ class CtaEngine(BaseEngine):
"""
return list(self.classes.keys())
def get_strategy_status(self, strategy_name):
"""
return strategy inited/trading status
:param strategy_name:
:return:
"""
inited = False
trading = False
strategy = self.strategies.get(strategy_name, None)
if strategy:
inited = strategy.inited
trading = strategy.trading
return inited, trading
def get_strategy_pos(self, name, strategy=None):
"""
获取策略的持仓字典
:param name:策略名
:return: [ {},{}]
"""
# 兼容处理如果strategy是None通过name获取
if strategy is None:
if name not in self.strategies:
self.write_log(u'getStategyPos 策略实例不存在:' + name)
return []
# 获取策略实例
strategy = self.strategies[name]
pos_list = []
if strategy.inited:
# 如果策略具有getPositions得方法则调用该方法
if hasattr(strategy, 'get_positions'):
pos_list = strategy.get_positions()
for pos in pos_list:
vt_symbol = pos.get('vt_symbol', None)
if vt_symbol:
symbol, exchange = extract_vt_symbol(vt_symbol)
pos.update({'symbol': symbol})
# 如果策略有 ctaPosition属性
elif hasattr(strategy, 'position') and issubclass(strategy.position, CtaPosition):
symbol, exchange = extract_vt_symbol(strategy.vt_symbol)
# 多仓
long_pos = {}
long_pos['vt_symbol'] = strategy.vt_symbol
long_pos['symbol'] = symbol
long_pos['direction'] = 'long'
long_pos['volume'] = strategy.position.long_pos
if long_pos['volume'] > 0:
pos_list.append(long_pos)
# 空仓
short_pos = {}
short_pos['vt_symbol'] = strategy.vt_symbol
short_pos['symbol'] = symbol
short_pos['direction'] = 'short'
short_pos['volume'] = abs(strategy.position.short_pos)
if short_pos['volume'] > 0:
pos_list.append(short_pos)
# 获取模板缺省pos属性
elif hasattr(strategy, 'pos') and isinstance(strategy.pos, int):
symbol, exchange = extract_vt_symbol(strategy.vt_symbol)
if strategy.pos > 0:
long_pos = {}
long_pos['vt_symbol'] = strategy.vt_symbol
long_pos['symbol'] = symbol
long_pos['direction'] = 'long'
long_pos['volume'] = strategy.pos
if long_pos['volume'] > 0:
pos_list.append(long_pos)
elif strategy.pos < 0:
short_pos = {}
short_pos['symbol'] = symbol
short_pos['vt_symbol'] = strategy.vt_symbol
short_pos['direction'] = 'short'
short_pos['volume'] = abs(strategy.pos)
if short_pos['volume'] > 0:
pos_list.append(short_pos)
# 新增处理SPD结尾得特殊自定义套利合约
try:
if strategy.vt_symbol.endswith('SPD') and len(pos_list) > 0:
old_pos_list = copy(pos_list)
pos_list = []
for pos in old_pos_list:
# SPD合约
spd_vt_symbol = pos.get('vt_symbol', None)
if spd_vt_symbol is not None and spd_vt_symbol.endswith('SPD'):
spd_symbol,spd_exchange = extract_vt_symbol(spd_vt_symbol)
spd_setting = self.main_engine.get_all_custom_contracts().get(spd_symbol, None)
if spd_setting is None:
self.write_error(u'获取不到:{}得设置信息,检查自定义合约配置文件'.format(spd_symbol))
pos_list.append(pos)
continue
leg1_direction = 'long' if pos.get('direction') in [Direction.LONG, 'long'] else 'short'
leg2_direction = 'short' if leg1_direction == 'long' else 'long'
spd_volume = pos.get('volume')
leg1_pos = {}
leg1_pos.update({'symbol': spd_setting.get('leg1_symbol')})
leg1_pos.update({'vt_symbol': spd_setting.get('leg1_symbol')})
leg1_pos.update({'direction': leg1_direction})
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1)*spd_volume})
leg2_pos = {}
leg2_pos.update({'symbol': spd_setting.get('leg2_symbol')})
leg2_pos.update({'vt_symbol': spd_setting.get('leg2_symbol')})
leg2_pos.update({'direction': leg2_direction})
leg2_pos.update({'volume': spd_setting.get('leg2_ratio', 1) * spd_volume})
pos_list.append(leg1_pos)
pos_list.append(leg2_pos)
else:
pos_list.append(pos)
except Exception as ex:
self.write_error(u'分解SPD失败')
# update local pos dict
self.strategy_pos_dict.update({name: pos_list})
return pos_list
def get_all_strategy_pos(self):
"""
获取所有得策略仓位明细
"""
strategy_pos_list = []
for strategy_name in list(self.strategies.keys()):
d = OrderedDict()
d['accountid'] = self.engine_config.get('accountid', '-')
d['strategy_group'] = self.engine_config.get('strategy_group', '-')
d['strategy_name'] = strategy_name
dt = datetime.now()
d['date'] = dt.strftime('%Y%m%d')
d['hour'] = dt.hour
d['datetime'] = datetime.now()
try:
d['pos'] = self.get_strategy_pos(name=strategy_name)
except Exception as ex:
self.write_error(
u'get_strategy_pos exception:{},{}'.format(str(ex), traceback.format_exc()))
d['pos'] = []
strategy_pos_list.append(d)
return strategy_pos_list
def get_strategy_class_parameters(self, class_name: str):
"""
Get default parameters of a strategy class.
@ -1103,6 +1289,10 @@ class CtaEngine(BaseEngine):
"""
Load setting file.
"""
# 读取引擎得配置
self.engine_config = load_json(self.engine_filename)
# 读取策略得配置
self.strategy_setting = load_json(self.setting_filename)
for strategy_name, strategy_config in self.strategy_setting.items():
@ -1151,6 +1341,12 @@ class CtaEngine(BaseEngine):
event = Event(EVENT_CTA_STRATEGY, data)
self.event_engine.put(event)
def put_all_strategy_pos_event(self, strategy_pos_list: list = []):
"""推送所有策略得持仓事件"""
for strategy_pos in strategy_pos_list:
event = Event(EVENT_STRATEGY_POS, copy(strategy_pos))
self.event_engine.put(event)
def write_log(self, msg: str, strategy_name: str = '', level: int = logging.INFO):
"""
Create cta engine log event.

View File

@ -2,6 +2,7 @@ from vnpy.app.cta_strategy_pro import (
CtaTemplate,
StopOrder,
Direction,
Offset,
TickData,
BarData,
TradeData,
@ -15,6 +16,7 @@ class TurtleSignalStrategy(CtaTemplate):
""""""
author = "用Python的交易员"
x_minute = 15
entry_window = 20
exit_window = 10
atr_window = 20
@ -31,7 +33,7 @@ class TurtleSignalStrategy(CtaTemplate):
long_stop = 0
short_stop = 0
parameters = ["entry_window", "exit_window", "atr_window", "fixed_size"]
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
@ -40,15 +42,17 @@ class TurtleSignalStrategy(CtaTemplate):
cta_engine, strategy_name, vt_symbol, setting
)
self.bg = BarGenerator(self.on_bar)
self.bg = BarGenerator(self.on_bar, window=self.x_minute)
self.am = ArrayManager()
self.cur_mi_price = None
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
self.load_bar(20)
#self.load_bar(20)
def on_start(self):
"""
@ -74,6 +78,8 @@ class TurtleSignalStrategy(CtaTemplate):
"""
self.cancel_all()
self.cur_mi_price = bar.close_price
self.am.update_bar(bar)
if not self.am.inited:
return
@ -106,7 +112,7 @@ class TurtleSignalStrategy(CtaTemplate):
self.send_short_orders(self.entry_down)
cover_price = min(self.short_stop, self.exit_up)
self.cover(cover_price, abs(self.pos), True)
ret = self.cover(cover_price, abs(self.pos), True)
self.put_event()
@ -114,12 +120,17 @@ class TurtleSignalStrategy(CtaTemplate):
"""
Callback of new trade data update.
"""
pre_pos = self.pos
if trade.direction == Direction.LONG:
self.long_entry = trade.price
self.long_stop = self.long_entry - 2 * self.atr_value
self.pos += trade.volume
else:
self.short_entry = trade.price
self.short_stop = self.short_entry + 2 * self.atr_value
self.pos -= trade.volume
self.write_log(f'{self.vt_symbol},pos {pre_pos} => {self.pos}')
def on_order(self, order: OrderData):
"""
@ -135,6 +146,12 @@ class TurtleSignalStrategy(CtaTemplate):
def send_buy_orders(self, price):
""""""
if self.pos >= 4:
return
if self.cur_mi_price <= price - self.atr_value/2:
return
t = self.pos / self.fixed_size
if t < 1:
@ -151,6 +168,12 @@ class TurtleSignalStrategy(CtaTemplate):
def send_short_orders(self, price):
""""""
if self.pos <= -4:
return
if self.cur_mi_price >= price + self.atr_value / 2:
return
t = self.pos / self.fixed_size
if t > -1:

View File

@ -0,0 +1,208 @@
from vnpy.app.cta_strategy_pro import (
CtaTemplate,
StopOrder,
Direction,
Offset,
TickData,
BarData,
TradeData,
OrderData,
BarGenerator,
ArrayManager,
)
class TurtleSignalStrategy_v2(CtaTemplate):
""""""
author = "用Python的交易员"
x_minute = 15
entry_window = 20
exit_window = 10
atr_window = 20
fixed_size = 1
invest_pos = 1
invest_percent = 10 # 投资比例
entry_up = 0
entry_down = 0
exit_up = 0
exit_down = 0
atr_value = 0
long_entry = 0
short_entry = 0
long_stop = 0
short_stop = 0
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
""""""
super(TurtleSignalStrategy_v2, self).__init__(
cta_engine, strategy_name, vt_symbol, setting
)
# 获取合约乘数,保证金比例
self.symbol_size = self.cta_engine.get_size(self.vt_symbol)
self.symbol_margin_rate = self.cta_engine.get_margin_rate(self.vt_symbol)
self.bg = BarGenerator(self.on_bar, window=self.x_minute)
self.am = ArrayManager()
self.cur_mi_price = None
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
#self.load_bar(20)
def on_start(self):
"""
Callback when strategy is started.
"""
self.write_log("策略启动")
def on_stop(self):
"""
Callback when strategy is stopped.
"""
self.write_log("策略停止")
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
self.bg.update_tick(tick)
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
self.cancel_all()
self.cur_mi_price = bar.close_price
self.am.update_bar(bar)
if not self.am.inited:
return
# Only calculates new entry channel when no position holding
if not self.pos:
self.entry_up, self.entry_down = self.am.donchian(
self.entry_window
)
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
if not self.pos:
self.atr_value = self.am.atr(self.atr_window)
self.long_entry = 0
self.short_entry = 0
self.long_stop = 0
self.short_stop = 0
self.send_buy_orders(self.entry_up)
self.send_short_orders(self.entry_down)
elif self.pos > 0:
self.send_buy_orders(self.entry_up)
sell_price = max(self.long_stop, self.exit_down)
self.sell(sell_price, abs(self.pos), True)
elif self.pos < 0:
self.send_short_orders(self.entry_down)
cover_price = min(self.short_stop, self.exit_up)
ret = self.cover(cover_price, abs(self.pos), True)
self.put_event()
def update_invest_pos(self):
"""计算获取投资仓位"""
# 获取账号资金
capital, available, cur_percent, percent_limit = self.cta_engine.get_account()
# 按照投资比例计算保证金
invest_margin = capital * self.invest_percent / 100
max_invest_pos = int(invest_margin / (self.cur_mi_price * self.symbol_size * self.symbol_margin_rate))
self.invest_pos = max(int(max_invest_pos / 4), 1)
def on_trade(self, trade: TradeData):
"""
Callback of new trade data update.
"""
pre_pos = self.pos
if trade.direction == Direction.LONG:
self.long_entry = trade.price
self.long_stop = self.long_entry - 2 * self.atr_value
self.pos += trade.volume
else:
self.short_entry = trade.price
self.short_stop = self.short_entry + 2 * self.atr_value
self.pos -= trade.volume
self.write_log(f'{self.vt_symbol},pos {pre_pos} => {self.pos}')
def on_order(self, order: OrderData):
"""
Callback of new order data update.
"""
pass
def on_stop_order(self, stop_order: StopOrder):
"""
Callback of stop order update.
"""
pass
def send_buy_orders(self, price):
""""""
if self.pos >= 4:
return
if self.cur_mi_price <= price - self.atr_value/2:
return
self.update_invest_pos()
t = self.pos / self.invest_pos
if t < 1:
self.buy(price, self.invest_pos, True)
if t < 2:
self.buy(price + self.atr_value * 0.5, self.invest_pos, True)
if t < 3:
self.buy(price + self.atr_value, self.invest_pos, True)
if t < 4:
self.buy(price + self.atr_value * 1.5, self.invest_pos, True)
def send_short_orders(self, price):
""""""
if self.pos <= -4:
return
if self.cur_mi_price >= price + self.atr_value / 2:
return
self.update_invest_pos()
t = self.pos / self.invest_pos
if t > -1:
self.short(price, self.invest_pos, True)
if t > -2:
self.short(price - self.atr_value * 0.5, self.invest_pos, True)
if t > -3:
self.short(price - self.atr_value, self.invest_pos, True)
if t > -4:
self.short(price - self.atr_value * 1.5, self.invest_pos, True)

View File

@ -119,6 +119,9 @@ class CtaTemplate(ABC):
}
return strategy_data
def on_timer(self):
pass
@virtual
def on_init(self):
"""