[Add]新增多信号CTA策略开发功能,close #567
This commit is contained in:
parent
c66f1c32b4
commit
cbb7b1b510
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@ -11,13 +11,14 @@
|
||||
"%matplotlib inline\n",
|
||||
"\n",
|
||||
"from vnpy.trader.app.ctaStrategy.ctaBacktesting import BacktestingEngine, OptimizationSetting, MINUTE_DB_NAME\n",
|
||||
"from vnpy.trader.app.ctaStrategy.strategy.strategyAtrRsi import AtrRsiStrategy\n",
|
||||
"from vnpy.trader.app.ctaStrategy.strategy.strategyMultiTimeframe import MultiTimeframeStrategy"
|
||||
"#from vnpy.trader.app.ctaStrategy.strategy.strategyAtrRsi import AtrRsiStrategy\n",
|
||||
"#from vnpy.trader.app.ctaStrategy.strategy.strategyMultiTimeframe import MultiTimeframeStrategy\n",
|
||||
"from vnpy.trader.app.ctaStrategy.strategy.strategyMultiSignal import MultiSignalStrategy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
@ -29,7 +30,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@ -38,12 +39,12 @@
|
||||
"# 设置回测使用的数据\n",
|
||||
"engine.setBacktestingMode(engine.BAR_MODE) # 设置引擎的回测模式为K线\n",
|
||||
"engine.setDatabase(MINUTE_DB_NAME, 'IF0000') # 设置使用的历史数据库\n",
|
||||
"engine.setStartDate('20120101') # 设置回测用的数据起始日期"
|
||||
"engine.setStartDate('20100101') # 设置回测用的数据起始日期"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
@ -59,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@ -68,16 +69,31 @@
|
||||
"# 在引擎中创建策略对象\n",
|
||||
"d = {'atrLength': 11} # 策略参数配置\n",
|
||||
"#engine.initStrategy(AtrRsiStrategy, d) # 创建策略对象\n",
|
||||
"engine.initStrategy(MultiTimeframeStrategy, d) # 创建策略对象"
|
||||
"#ngine.initStrategy(MultiTimeframeStrategy, d) \n",
|
||||
"engine.initStrategy(MultiSignalStrategy, {}) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2017-12-15 23:01:54.728000\t开始载入数据\n",
|
||||
"2017-12-15 23:01:54.765000\t载入完成,数据量:0\n",
|
||||
"2017-12-15 23:01:54.765000\t开始回测\n",
|
||||
"2017-12-15 23:01:54.765000\t策略初始化完成\n",
|
||||
"2017-12-15 23:01:54.765000\t策略启动完成\n",
|
||||
"2017-12-15 23:01:54.765000\t开始回放数据\n",
|
||||
"2017-12-15 23:01:54.766000\t数据回放结束\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 运行回测\n",
|
||||
"engine.runBacktesting() # 运行回测"
|
||||
|
File diff suppressed because one or more lines are too long
@ -601,3 +601,41 @@ class ArrayManager(object):
|
||||
if array:
|
||||
return up, down
|
||||
return up[-1], down[-1]
|
||||
|
||||
|
||||
########################################################################
|
||||
class CtaSignal(object):
|
||||
"""
|
||||
CTA策略信号,负责纯粹的信号生成(目标仓位),不参与具体交易管理
|
||||
"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
self.signalPos = 0 # 信号仓位
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, bar):
|
||||
"""K线推送"""
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""Tick推送"""
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def setSignalPos(self, pos):
|
||||
"""设置信号仓位"""
|
||||
self.signalPos = pos
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def getSignalPos(self):
|
||||
"""获取信号仓位"""
|
||||
return self.signalPos
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
251
vnpy/trader/app/ctaStrategy/strategy/strategyMultiSignal.py
Normal file
251
vnpy/trader/app/ctaStrategy/strategy/strategyMultiSignal.py
Normal file
@ -0,0 +1,251 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
"""
|
||||
一个多信号组合策略,基于的信号包括:
|
||||
RSI(1分钟):大于70为多头、低于30为空头
|
||||
CCI(1分钟):大于10为多头、低于-10为空头
|
||||
MA(5分钟):快速大于慢速为多头、低于慢速为空头
|
||||
"""
|
||||
|
||||
from vnpy.trader.vtObject import VtBarData
|
||||
from vnpy.trader.vtConstant import EMPTY_STRING
|
||||
from vnpy.trader.app.ctaStrategy.ctaTemplate import (TargetPosTemplate,
|
||||
CtaSignal,
|
||||
BarGenerator,
|
||||
ArrayManager)
|
||||
|
||||
|
||||
########################################################################
|
||||
class RsiSignal(CtaSignal):
|
||||
"""RSI信号"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
super(RsiSignal, self).__init__()
|
||||
|
||||
self.rsiWindow = 14
|
||||
self.rsiLevel = 20
|
||||
self.rsiLong = 50 + self.rsiLevel
|
||||
self.rsiShort = 50 - self.rsiLevel
|
||||
|
||||
self.bg = BarGenerator(self.onBar)
|
||||
self.am = ArrayManager()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""Tick更新"""
|
||||
self.bg.updateTick(tick)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, bar):
|
||||
"""K线更新"""
|
||||
self.am.updateBar(bar)
|
||||
|
||||
if not self.am.inited:
|
||||
self.setSignalPos(0)
|
||||
|
||||
rsiValue = self.am.rsi(self.rsiWindow)
|
||||
|
||||
if rsiValue >= self.rsiLong:
|
||||
self.setSignalPos(1)
|
||||
elif rsiValue <= self.rsiShort:
|
||||
self.setSignalPos(-1)
|
||||
else:
|
||||
self.setSignalPos(0)
|
||||
|
||||
|
||||
########################################################################
|
||||
class CciSignal(CtaSignal):
|
||||
"""CCI信号"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
super(CciSignal, self).__init__()
|
||||
|
||||
self.cciWindow = 30
|
||||
self.cciLevel = 10
|
||||
self.cciLong = self.cciLevel
|
||||
self.cciShort = -self.cciLevel
|
||||
|
||||
self.bg = BarGenerator(self.onBar)
|
||||
self.am = ArrayManager()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""Tick更新"""
|
||||
self.bg.updateTick(tick)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, bar):
|
||||
"""K线更新"""
|
||||
self.am.updateBar(bar)
|
||||
|
||||
if not self.am.inited:
|
||||
self.setSignalPos(0)
|
||||
|
||||
cciValue = self.am.cci(self.cciWindow)
|
||||
|
||||
if cciValue >= self.cciLong:
|
||||
self.setSignalPos(1)
|
||||
elif cciValue<= self.cciShort:
|
||||
self.setSignalPos(-1)
|
||||
else:
|
||||
self.setSignalPos(0)
|
||||
|
||||
|
||||
########################################################################
|
||||
class MaSignal(CtaSignal):
|
||||
"""双均线信号"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
super(MaSignal, self).__init__()
|
||||
|
||||
self.fastWindow = 5
|
||||
self.slowWindow = 20
|
||||
|
||||
self.bg = BarGenerator(self.onBar, 5, self.onFiveBar)
|
||||
self.am = ArrayManager()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""Tick更新"""
|
||||
self.bg.updateTick(tick)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, bar):
|
||||
"""K线更新"""
|
||||
self.bg.updateBar(bar)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onFiveBar(self, bar):
|
||||
"""5分钟K线更新"""
|
||||
self.am.updateBar(bar)
|
||||
|
||||
if not self.am.inited:
|
||||
self.setSignalPos(0)
|
||||
|
||||
fastMa = self.am.sma(self.fastWindow)
|
||||
slowMa = self.am.sma(self.slowWindow)
|
||||
|
||||
if fastMa > slowMa:
|
||||
self.setSignalPos(1)
|
||||
elif fastMa < slowMa:
|
||||
self.setSignalPos(-1)
|
||||
else:
|
||||
self.setSignalPos(0)
|
||||
|
||||
|
||||
########################################################################
|
||||
class MultiSignalStrategy(TargetPosTemplate):
|
||||
"""跨时间周期交易策略"""
|
||||
className = 'MultiSignalStrategy'
|
||||
author = u'用Python的交易员'
|
||||
|
||||
# 策略参数
|
||||
initDays = 10 # 初始化数据所用的天数
|
||||
fixedSize = 1 # 每次交易的数量
|
||||
|
||||
# 策略变量
|
||||
signalPos = {} # 信号仓位
|
||||
|
||||
# 参数列表,保存了参数的名称
|
||||
paramList = ['name',
|
||||
'className',
|
||||
'author',
|
||||
'vtSymbol']
|
||||
|
||||
# 变量列表,保存了变量的名称
|
||||
varList = ['inited',
|
||||
'trading',
|
||||
'pos',
|
||||
'signalPos',
|
||||
'targetPos']
|
||||
|
||||
# 同步列表,保存了需要保存到数据库的变量名称
|
||||
syncList = ['pos']
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, ctaEngine, setting):
|
||||
"""Constructor"""
|
||||
super(MultiSignalStrategy, self).__init__(ctaEngine, setting)
|
||||
|
||||
self.rsiSignal = RsiSignal()
|
||||
self.cciSignal = CciSignal()
|
||||
self.maSignal = MaSignal()
|
||||
|
||||
self.signalPos = {
|
||||
"rsi": 0,
|
||||
"cci": 0,
|
||||
"ma": 0
|
||||
}
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onInit(self):
|
||||
"""初始化策略(必须由用户继承实现)"""
|
||||
self.writeCtaLog(u'%s策略初始化' %self.name)
|
||||
|
||||
# 载入历史数据,并采用回放计算的方式初始化策略数值
|
||||
initData = self.loadBar(self.initDays)
|
||||
for bar in initData:
|
||||
self.onBar(bar)
|
||||
|
||||
self.putEvent()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onStart(self):
|
||||
"""启动策略(必须由用户继承实现)"""
|
||||
self.writeCtaLog(u'%s策略启动' %self.name)
|
||||
self.putEvent()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onStop(self):
|
||||
"""停止策略(必须由用户继承实现)"""
|
||||
self.writeCtaLog(u'%s策略停止' %self.name)
|
||||
self.putEvent()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""收到行情TICK推送(必须由用户继承实现)"""
|
||||
super(MultiSignalStrategy, self).onTick(tick)
|
||||
|
||||
self.rsiSignal.onTick(tick)
|
||||
self.cciSignal.onTick(tick)
|
||||
self.maSignal.onTick(tick)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, bar):
|
||||
"""收到Bar推送(必须由用户继承实现)"""
|
||||
super(MultiSignalStrategy, self).onBar(bar)
|
||||
|
||||
self.rsiSignal.onBar(bar)
|
||||
self.cciSignal.onBar(bar)
|
||||
self.maSignal.onBar(bar)
|
||||
|
||||
self.signalPos['rsi'] = self.rsiSignal.getSignalPos()
|
||||
self.signalPos['cci'] = self.cciSignal.getSignalPos()
|
||||
self.signalPos['ma'] = self.maSignal.getSignalPos()
|
||||
|
||||
targetPos = 0
|
||||
for v in self.signalPos.values():
|
||||
targetPos += v
|
||||
|
||||
self.setTargetPos(targetPos)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onOrder(self, order):
|
||||
"""收到委托变化推送(必须由用户继承实现)"""
|
||||
super(MultiSignalStrategy, self).onOrder(order)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTrade(self, trade):
|
||||
# 发出状态更新事件
|
||||
self.putEvent()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onStopOrder(self, so):
|
||||
"""停止单推送"""
|
||||
pass
|
Loading…
Reference in New Issue
Block a user