[增强] 支持策略on_timer, 获取账号资金/风控,定时更新策略持仓
This commit is contained in:
parent
e0d7dff896
commit
5208c85d77
@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
from vnpy.trader.constant import Direction
|
||||
from vnpy.trader.constant import Direction,Offset
|
||||
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
|
||||
from vnpy.trader.utility import BarGenerator, ArrayManager
|
||||
|
||||
|
@ -26,11 +26,13 @@ from vnpy.trader.object import (
|
||||
ContractData
|
||||
)
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TIMER,
|
||||
EVENT_TICK,
|
||||
EVENT_BAR,
|
||||
EVENT_ORDER,
|
||||
EVENT_TRADE,
|
||||
EVENT_POSITION
|
||||
EVENT_POSITION,
|
||||
EVENT_STRATEGY_POS,
|
||||
)
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
@ -40,8 +42,13 @@ from vnpy.trader.constant import (
|
||||
Offset,
|
||||
Status
|
||||
)
|
||||
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to, get_folder_path, \
|
||||
get_underlying_symbol
|
||||
from vnpy.trader.utility import (
|
||||
load_json, save_json,
|
||||
extract_vt_symbol,
|
||||
round_to, get_folder_path,
|
||||
get_underlying_symbol,
|
||||
append_data)
|
||||
|
||||
from vnpy.trader.util_logger import setup_logger, logging
|
||||
from vnpy.trader.converter import OffsetConverter
|
||||
|
||||
@ -58,6 +65,7 @@ from .base import (
|
||||
|
||||
)
|
||||
from .template import CtaTemplate
|
||||
from .cta_position import CtaPosition
|
||||
|
||||
STOP_STATUS_MAP = {
|
||||
Status.SUBMITTING: StopOrderStatus.WAITING,
|
||||
@ -80,16 +88,21 @@ class CtaEngine(BaseEngine):
|
||||
6、支持指定gateway的交易。主引擎可接入多个gateway
|
||||
"""
|
||||
|
||||
engine_type = EngineType.LIVE # live trading engine
|
||||
engine_type = EngineType.LIVE # live trading engine
|
||||
|
||||
setting_filename = "cta_strategy_setting.json"
|
||||
data_filename = "cta_strategy_data.json"
|
||||
# 策略配置文件
|
||||
setting_filename = "cta_strategy_pro_setting.json"
|
||||
# 引擎配置文件
|
||||
engine_filename = "cta_strategy_pro_config.json"
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super(CtaEngine, self).__init__(
|
||||
main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.engine_config = {}
|
||||
|
||||
self.strategy_setting = {} # strategy_name: dict
|
||||
self.strategy_data = {} # strategy_name: dict
|
||||
|
||||
@ -97,6 +110,9 @@ class CtaEngine(BaseEngine):
|
||||
self.class_module_map = {} # class_name: mudule_name
|
||||
self.strategies = {} # strategy_name: strategy
|
||||
|
||||
# Strategy pos dict,key:strategy instance name, value: pos dict
|
||||
self.strategy_pos_dict = {}
|
||||
|
||||
self.strategy_loggers = {} # strategy_name: logger
|
||||
|
||||
# 未能订阅的symbols,支持策略启动时,并未接入gateway
|
||||
@ -120,6 +136,8 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
self.offset_converter = OffsetConverter(self.main_engine)
|
||||
|
||||
self.last_minute = None
|
||||
|
||||
def init_engine(self):
|
||||
"""
|
||||
"""
|
||||
@ -128,43 +146,44 @@ class CtaEngine(BaseEngine):
|
||||
self.register_event()
|
||||
self.write_log("CTA策略引擎初始化成功")
|
||||
|
||||
def append_data(self, file_name: str, dict_data: dict, field_names: list = []):
|
||||
"""
|
||||
添加数据到csv文件中
|
||||
:param file_name: csv的文件全路径
|
||||
:param dict_data: OrderedDict
|
||||
:return:
|
||||
"""
|
||||
dict_fieldnames = sorted(list(dict_data.keys())) if len(field_names) == 0 else field_names
|
||||
|
||||
try:
|
||||
if not os.path.exists(file_name):
|
||||
self.write_log(u'create csv file:{}'.format(file_name))
|
||||
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
|
||||
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel')
|
||||
self.write_log(u'write csv header:{}'.format(dict_fieldnames))
|
||||
writer.writeheader()
|
||||
writer.writerow(dict_data)
|
||||
else:
|
||||
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
|
||||
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel',
|
||||
extrasaction='ignore')
|
||||
writer.writerow(dict_data)
|
||||
except Exception as ex:
|
||||
self.write_error(u'append_data exception:{}'.format(str(ex)))
|
||||
|
||||
def close(self):
|
||||
"""停止所属有的策略"""
|
||||
self.stop_all_strategies()
|
||||
|
||||
def register_event(self):
|
||||
"""注册事件"""
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
self.event_engine.register(EVENT_TICK, self.process_tick_event)
|
||||
self.event_engine.register(EVENT_BAR, self.process_bar_event)
|
||||
self.event_engine.register(EVENT_ORDER, self.process_order_event)
|
||||
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
|
||||
self.event_engine.register(EVENT_POSITION, self.process_position_event)
|
||||
|
||||
def register_funcs(self):
|
||||
"""
|
||||
register the funcs to main_engine
|
||||
:return:
|
||||
"""
|
||||
self.main_engine.get_strategy_status = self.get_strategy_status
|
||||
|
||||
def process_timer_event(self, event: Event):
|
||||
""" 处理定时器事件"""
|
||||
|
||||
# 触发每个策略的定时接口
|
||||
for strategy in list(self.strategies.values()):
|
||||
strategy.on_timer()
|
||||
|
||||
dt = datetime.now()
|
||||
|
||||
if self.last_minute != dt.minute:
|
||||
self.last_minute = dt.minute
|
||||
|
||||
# 主动获取所有策略得持仓信息
|
||||
all_strategy_pos = self.get_all_strategy_pos()
|
||||
|
||||
# 推送到事件
|
||||
self.put_all_strategy_pos_event(all_strategy_pos)
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
"""处理tick到达事件"""
|
||||
tick = event.data
|
||||
@ -269,7 +288,7 @@ class CtaEngine(BaseEngine):
|
||||
if strategy_name is not None:
|
||||
trade_file = os.path.abspath(
|
||||
os.path.join(get_folder_path('data'), '{}_trade.csv'.format(strategy_name)))
|
||||
self.append_data(file_name=trade_file, dict_data=trade_dict)
|
||||
append_data(file_name=trade_file, dict_data=trade_dict)
|
||||
except Exception as ex:
|
||||
self.write_error(u'写入交易记录csv出错:{},{}'.format(str(ex), traceback.format_exc()))
|
||||
|
||||
@ -737,9 +756,22 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
return None
|
||||
|
||||
def get_account(self, vt_accountid: str):
|
||||
def get_account(self, vt_accountid: str = ""):
|
||||
""" 查询账号的资金"""
|
||||
return self.main_engine.get_account(vt_accountid)
|
||||
# 如果启动风控,则使用风控中的最大仓位
|
||||
if self.main_engine.rm_engine:
|
||||
return self.main_engine.rm_engine.get_account(vt_accountid)
|
||||
|
||||
if len(vt_accountid) > 0:
|
||||
account = self.main_engine.get_account(vt_accountid)
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
|
||||
else:
|
||||
accounts = self.main_engine.get_all_accounts()
|
||||
if len(accounts) > 0:
|
||||
account = accounts[0]
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
|
||||
else:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
|
||||
""" 查询合约在账号的持仓,需要指定方向"""
|
||||
@ -771,7 +803,7 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
msg = f"触发异常已停止\n{traceback.format_exc()}"
|
||||
self.write_log(msg=msg,
|
||||
strategy_name=strategy.name,
|
||||
strategy_name=strategy.strategy_name,
|
||||
level=logging.CRITICAL)
|
||||
|
||||
def add_strategy(
|
||||
@ -1062,6 +1094,160 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
return list(self.classes.keys())
|
||||
|
||||
def get_strategy_status(self, strategy_name):
|
||||
"""
|
||||
return strategy inited/trading status
|
||||
:param strategy_name:
|
||||
:return:
|
||||
"""
|
||||
inited = False
|
||||
trading = False
|
||||
|
||||
strategy = self.strategies.get(strategy_name, None)
|
||||
if strategy:
|
||||
inited = strategy.inited
|
||||
trading = strategy.trading
|
||||
|
||||
return inited, trading
|
||||
|
||||
def get_strategy_pos(self, name, strategy=None):
|
||||
"""
|
||||
获取策略的持仓字典
|
||||
:param name:策略名
|
||||
:return: [ {},{}]
|
||||
"""
|
||||
# 兼容处理,如果strategy是None,通过name获取
|
||||
if strategy is None:
|
||||
if name not in self.strategies:
|
||||
self.write_log(u'getStategyPos 策略实例不存在:' + name)
|
||||
return []
|
||||
# 获取策略实例
|
||||
strategy = self.strategies[name]
|
||||
|
||||
pos_list = []
|
||||
|
||||
if strategy.inited:
|
||||
# 如果策略具有getPositions得方法,则调用该方法
|
||||
if hasattr(strategy, 'get_positions'):
|
||||
pos_list = strategy.get_positions()
|
||||
for pos in pos_list:
|
||||
vt_symbol = pos.get('vt_symbol', None)
|
||||
if vt_symbol:
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
pos.update({'symbol': symbol})
|
||||
|
||||
# 如果策略有 ctaPosition属性
|
||||
elif hasattr(strategy, 'position') and issubclass(strategy.position, CtaPosition):
|
||||
symbol, exchange = extract_vt_symbol(strategy.vt_symbol)
|
||||
# 多仓
|
||||
long_pos = {}
|
||||
long_pos['vt_symbol'] = strategy.vt_symbol
|
||||
long_pos['symbol'] = symbol
|
||||
long_pos['direction'] = 'long'
|
||||
long_pos['volume'] = strategy.position.long_pos
|
||||
if long_pos['volume'] > 0:
|
||||
pos_list.append(long_pos)
|
||||
|
||||
# 空仓
|
||||
short_pos = {}
|
||||
short_pos['vt_symbol'] = strategy.vt_symbol
|
||||
short_pos['symbol'] = symbol
|
||||
short_pos['direction'] = 'short'
|
||||
short_pos['volume'] = abs(strategy.position.short_pos)
|
||||
if short_pos['volume'] > 0:
|
||||
pos_list.append(short_pos)
|
||||
|
||||
# 获取模板缺省pos属性
|
||||
elif hasattr(strategy, 'pos') and isinstance(strategy.pos, int):
|
||||
symbol, exchange = extract_vt_symbol(strategy.vt_symbol)
|
||||
if strategy.pos > 0:
|
||||
long_pos = {}
|
||||
long_pos['vt_symbol'] = strategy.vt_symbol
|
||||
long_pos['symbol'] = symbol
|
||||
long_pos['direction'] = 'long'
|
||||
long_pos['volume'] = strategy.pos
|
||||
if long_pos['volume'] > 0:
|
||||
pos_list.append(long_pos)
|
||||
elif strategy.pos < 0:
|
||||
short_pos = {}
|
||||
short_pos['symbol'] = symbol
|
||||
short_pos['vt_symbol'] = strategy.vt_symbol
|
||||
short_pos['direction'] = 'short'
|
||||
short_pos['volume'] = abs(strategy.pos)
|
||||
if short_pos['volume'] > 0:
|
||||
pos_list.append(short_pos)
|
||||
|
||||
# 新增处理SPD结尾得特殊自定义套利合约
|
||||
try:
|
||||
if strategy.vt_symbol.endswith('SPD') and len(pos_list) > 0:
|
||||
old_pos_list = copy(pos_list)
|
||||
pos_list = []
|
||||
for pos in old_pos_list:
|
||||
# SPD合约
|
||||
spd_vt_symbol = pos.get('vt_symbol', None)
|
||||
if spd_vt_symbol is not None and spd_vt_symbol.endswith('SPD'):
|
||||
spd_symbol,spd_exchange = extract_vt_symbol(spd_vt_symbol)
|
||||
spd_setting = self.main_engine.get_all_custom_contracts().get(spd_symbol, None)
|
||||
|
||||
if spd_setting is None:
|
||||
self.write_error(u'获取不到:{}得设置信息,检查自定义合约配置文件'.format(spd_symbol))
|
||||
pos_list.append(pos)
|
||||
continue
|
||||
|
||||
leg1_direction = 'long' if pos.get('direction') in [Direction.LONG, 'long'] else 'short'
|
||||
leg2_direction = 'short' if leg1_direction == 'long' else 'long'
|
||||
spd_volume = pos.get('volume')
|
||||
|
||||
leg1_pos = {}
|
||||
leg1_pos.update({'symbol': spd_setting.get('leg1_symbol')})
|
||||
leg1_pos.update({'vt_symbol': spd_setting.get('leg1_symbol')})
|
||||
leg1_pos.update({'direction': leg1_direction})
|
||||
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1)*spd_volume})
|
||||
|
||||
leg2_pos = {}
|
||||
leg2_pos.update({'symbol': spd_setting.get('leg2_symbol')})
|
||||
leg2_pos.update({'vt_symbol': spd_setting.get('leg2_symbol')})
|
||||
leg2_pos.update({'direction': leg2_direction})
|
||||
leg2_pos.update({'volume': spd_setting.get('leg2_ratio', 1) * spd_volume})
|
||||
|
||||
pos_list.append(leg1_pos)
|
||||
pos_list.append(leg2_pos)
|
||||
|
||||
else:
|
||||
pos_list.append(pos)
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'分解SPD失败')
|
||||
|
||||
# update local pos dict
|
||||
self.strategy_pos_dict.update({name: pos_list})
|
||||
|
||||
return pos_list
|
||||
|
||||
def get_all_strategy_pos(self):
|
||||
"""
|
||||
获取所有得策略仓位明细
|
||||
"""
|
||||
strategy_pos_list = []
|
||||
for strategy_name in list(self.strategies.keys()):
|
||||
d = OrderedDict()
|
||||
d['accountid'] = self.engine_config.get('accountid', '-')
|
||||
d['strategy_group'] = self.engine_config.get('strategy_group', '-')
|
||||
d['strategy_name'] = strategy_name
|
||||
dt = datetime.now()
|
||||
d['date'] = dt.strftime('%Y%m%d')
|
||||
d['hour'] = dt.hour
|
||||
d['datetime'] = datetime.now()
|
||||
try:
|
||||
d['pos'] = self.get_strategy_pos(name=strategy_name)
|
||||
except Exception as ex:
|
||||
self.write_error(
|
||||
u'get_strategy_pos exception:{},{}'.format(str(ex), traceback.format_exc()))
|
||||
d['pos'] = []
|
||||
strategy_pos_list.append(d)
|
||||
|
||||
return strategy_pos_list
|
||||
|
||||
def get_strategy_class_parameters(self, class_name: str):
|
||||
"""
|
||||
Get default parameters of a strategy class.
|
||||
@ -1103,6 +1289,10 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Load setting file.
|
||||
"""
|
||||
# 读取引擎得配置
|
||||
self.engine_config = load_json(self.engine_filename)
|
||||
|
||||
# 读取策略得配置
|
||||
self.strategy_setting = load_json(self.setting_filename)
|
||||
|
||||
for strategy_name, strategy_config in self.strategy_setting.items():
|
||||
@ -1151,6 +1341,12 @@ class CtaEngine(BaseEngine):
|
||||
event = Event(EVENT_CTA_STRATEGY, data)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def put_all_strategy_pos_event(self, strategy_pos_list: list = []):
|
||||
"""推送所有策略得持仓事件"""
|
||||
for strategy_pos in strategy_pos_list:
|
||||
event = Event(EVENT_STRATEGY_POS, copy(strategy_pos))
|
||||
self.event_engine.put(event)
|
||||
|
||||
def write_log(self, msg: str, strategy_name: str = '', level: int = logging.INFO):
|
||||
"""
|
||||
Create cta engine log event.
|
||||
|
@ -2,6 +2,7 @@ from vnpy.app.cta_strategy_pro import (
|
||||
CtaTemplate,
|
||||
StopOrder,
|
||||
Direction,
|
||||
Offset,
|
||||
TickData,
|
||||
BarData,
|
||||
TradeData,
|
||||
@ -15,6 +16,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
""""""
|
||||
author = "用Python的交易员"
|
||||
|
||||
x_minute = 15
|
||||
entry_window = 20
|
||||
exit_window = 10
|
||||
atr_window = 20
|
||||
@ -31,7 +33,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
long_stop = 0
|
||||
short_stop = 0
|
||||
|
||||
parameters = ["entry_window", "exit_window", "atr_window", "fixed_size"]
|
||||
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
|
||||
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
|
||||
|
||||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||
@ -40,15 +42,17 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
cta_engine, strategy_name, vt_symbol, setting
|
||||
)
|
||||
|
||||
self.bg = BarGenerator(self.on_bar)
|
||||
self.bg = BarGenerator(self.on_bar, window=self.x_minute)
|
||||
self.am = ArrayManager()
|
||||
|
||||
self.cur_mi_price = None
|
||||
|
||||
def on_init(self):
|
||||
"""
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
self.write_log("策略初始化")
|
||||
self.load_bar(20)
|
||||
#self.load_bar(20)
|
||||
|
||||
def on_start(self):
|
||||
"""
|
||||
@ -74,6 +78,8 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
"""
|
||||
self.cancel_all()
|
||||
|
||||
self.cur_mi_price = bar.close_price
|
||||
|
||||
self.am.update_bar(bar)
|
||||
if not self.am.inited:
|
||||
return
|
||||
@ -106,7 +112,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
self.send_short_orders(self.entry_down)
|
||||
|
||||
cover_price = min(self.short_stop, self.exit_up)
|
||||
self.cover(cover_price, abs(self.pos), True)
|
||||
ret = self.cover(cover_price, abs(self.pos), True)
|
||||
|
||||
self.put_event()
|
||||
|
||||
@ -114,12 +120,17 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
"""
|
||||
Callback of new trade data update.
|
||||
"""
|
||||
pre_pos = self.pos
|
||||
if trade.direction == Direction.LONG:
|
||||
self.long_entry = trade.price
|
||||
self.long_stop = self.long_entry - 2 * self.atr_value
|
||||
self.pos += trade.volume
|
||||
else:
|
||||
self.short_entry = trade.price
|
||||
self.short_stop = self.short_entry + 2 * self.atr_value
|
||||
self.pos -= trade.volume
|
||||
|
||||
self.write_log(f'{self.vt_symbol},pos {pre_pos} => {self.pos}')
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
@ -135,6 +146,12 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
|
||||
def send_buy_orders(self, price):
|
||||
""""""
|
||||
if self.pos >= 4:
|
||||
return
|
||||
|
||||
if self.cur_mi_price <= price - self.atr_value/2:
|
||||
return
|
||||
|
||||
t = self.pos / self.fixed_size
|
||||
|
||||
if t < 1:
|
||||
@ -151,6 +168,12 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
|
||||
def send_short_orders(self, price):
|
||||
""""""
|
||||
if self.pos <= -4:
|
||||
return
|
||||
|
||||
if self.cur_mi_price >= price + self.atr_value / 2:
|
||||
return
|
||||
|
||||
t = self.pos / self.fixed_size
|
||||
|
||||
if t > -1:
|
||||
|
@ -0,0 +1,208 @@
|
||||
from vnpy.app.cta_strategy_pro import (
|
||||
CtaTemplate,
|
||||
StopOrder,
|
||||
Direction,
|
||||
Offset,
|
||||
TickData,
|
||||
BarData,
|
||||
TradeData,
|
||||
OrderData,
|
||||
BarGenerator,
|
||||
ArrayManager,
|
||||
)
|
||||
|
||||
|
||||
class TurtleSignalStrategy_v2(CtaTemplate):
|
||||
""""""
|
||||
author = "用Python的交易员"
|
||||
|
||||
x_minute = 15
|
||||
entry_window = 20
|
||||
exit_window = 10
|
||||
atr_window = 20
|
||||
fixed_size = 1
|
||||
invest_pos = 1
|
||||
invest_percent = 10 # 投资比例
|
||||
|
||||
entry_up = 0
|
||||
entry_down = 0
|
||||
exit_up = 0
|
||||
exit_down = 0
|
||||
atr_value = 0
|
||||
|
||||
long_entry = 0
|
||||
short_entry = 0
|
||||
long_stop = 0
|
||||
short_stop = 0
|
||||
|
||||
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
|
||||
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
|
||||
|
||||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||
""""""
|
||||
super(TurtleSignalStrategy_v2, self).__init__(
|
||||
cta_engine, strategy_name, vt_symbol, setting
|
||||
)
|
||||
|
||||
# 获取合约乘数,保证金比例
|
||||
self.symbol_size = self.cta_engine.get_size(self.vt_symbol)
|
||||
self.symbol_margin_rate = self.cta_engine.get_margin_rate(self.vt_symbol)
|
||||
|
||||
self.bg = BarGenerator(self.on_bar, window=self.x_minute)
|
||||
self.am = ArrayManager()
|
||||
|
||||
self.cur_mi_price = None
|
||||
|
||||
def on_init(self):
|
||||
"""
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
self.write_log("策略初始化")
|
||||
#self.load_bar(20)
|
||||
|
||||
def on_start(self):
|
||||
"""
|
||||
Callback when strategy is started.
|
||||
"""
|
||||
self.write_log("策略启动")
|
||||
|
||||
def on_stop(self):
|
||||
"""
|
||||
Callback when strategy is stopped.
|
||||
"""
|
||||
self.write_log("策略停止")
|
||||
|
||||
def on_tick(self, tick: TickData):
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
"""
|
||||
self.bg.update_tick(tick)
|
||||
|
||||
def on_bar(self, bar: BarData):
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
"""
|
||||
self.cancel_all()
|
||||
|
||||
self.cur_mi_price = bar.close_price
|
||||
|
||||
self.am.update_bar(bar)
|
||||
if not self.am.inited:
|
||||
return
|
||||
|
||||
# Only calculates new entry channel when no position holding
|
||||
if not self.pos:
|
||||
self.entry_up, self.entry_down = self.am.donchian(
|
||||
self.entry_window
|
||||
)
|
||||
|
||||
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
|
||||
|
||||
if not self.pos:
|
||||
self.atr_value = self.am.atr(self.atr_window)
|
||||
|
||||
self.long_entry = 0
|
||||
self.short_entry = 0
|
||||
self.long_stop = 0
|
||||
self.short_stop = 0
|
||||
|
||||
self.send_buy_orders(self.entry_up)
|
||||
self.send_short_orders(self.entry_down)
|
||||
elif self.pos > 0:
|
||||
self.send_buy_orders(self.entry_up)
|
||||
|
||||
sell_price = max(self.long_stop, self.exit_down)
|
||||
self.sell(sell_price, abs(self.pos), True)
|
||||
|
||||
elif self.pos < 0:
|
||||
self.send_short_orders(self.entry_down)
|
||||
|
||||
cover_price = min(self.short_stop, self.exit_up)
|
||||
ret = self.cover(cover_price, abs(self.pos), True)
|
||||
|
||||
self.put_event()
|
||||
|
||||
def update_invest_pos(self):
|
||||
"""计算获取投资仓位"""
|
||||
# 获取账号资金
|
||||
capital, available, cur_percent, percent_limit = self.cta_engine.get_account()
|
||||
# 按照投资比例计算保证金
|
||||
invest_margin = capital * self.invest_percent / 100
|
||||
max_invest_pos = int(invest_margin / (self.cur_mi_price * self.symbol_size * self.symbol_margin_rate))
|
||||
self.invest_pos = max(int(max_invest_pos / 4), 1)
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
"""
|
||||
Callback of new trade data update.
|
||||
"""
|
||||
pre_pos = self.pos
|
||||
if trade.direction == Direction.LONG:
|
||||
self.long_entry = trade.price
|
||||
self.long_stop = self.long_entry - 2 * self.atr_value
|
||||
self.pos += trade.volume
|
||||
else:
|
||||
self.short_entry = trade.price
|
||||
self.short_stop = self.short_entry + 2 * self.atr_value
|
||||
self.pos -= trade.volume
|
||||
|
||||
self.write_log(f'{self.vt_symbol},pos {pre_pos} => {self.pos}')
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
Callback of new order data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_stop_order(self, stop_order: StopOrder):
|
||||
"""
|
||||
Callback of stop order update.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_buy_orders(self, price):
|
||||
""""""
|
||||
if self.pos >= 4:
|
||||
return
|
||||
|
||||
if self.cur_mi_price <= price - self.atr_value/2:
|
||||
return
|
||||
|
||||
self.update_invest_pos()
|
||||
|
||||
t = self.pos / self.invest_pos
|
||||
|
||||
if t < 1:
|
||||
self.buy(price, self.invest_pos, True)
|
||||
|
||||
if t < 2:
|
||||
self.buy(price + self.atr_value * 0.5, self.invest_pos, True)
|
||||
|
||||
if t < 3:
|
||||
self.buy(price + self.atr_value, self.invest_pos, True)
|
||||
|
||||
if t < 4:
|
||||
self.buy(price + self.atr_value * 1.5, self.invest_pos, True)
|
||||
|
||||
def send_short_orders(self, price):
|
||||
""""""
|
||||
if self.pos <= -4:
|
||||
return
|
||||
|
||||
if self.cur_mi_price >= price + self.atr_value / 2:
|
||||
return
|
||||
|
||||
self.update_invest_pos()
|
||||
|
||||
t = self.pos / self.invest_pos
|
||||
|
||||
if t > -1:
|
||||
self.short(price, self.invest_pos, True)
|
||||
|
||||
if t > -2:
|
||||
self.short(price - self.atr_value * 0.5, self.invest_pos, True)
|
||||
|
||||
if t > -3:
|
||||
self.short(price - self.atr_value, self.invest_pos, True)
|
||||
|
||||
if t > -4:
|
||||
self.short(price - self.atr_value * 1.5, self.invest_pos, True)
|
@ -119,6 +119,9 @@ class CtaTemplate(ABC):
|
||||
}
|
||||
return strategy_data
|
||||
|
||||
def on_timer(self):
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_init(self):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user