[Add]新增多信号CTA策略开发功能,close #567

This commit is contained in:
vn.py 2017-12-15 23:05:22 +08:00
parent c66f1c32b4
commit cbb7b1b510
4 changed files with 327 additions and 75 deletions

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"collapsed": false
},
@ -11,13 +11,14 @@
"%matplotlib inline\n",
"\n",
"from vnpy.trader.app.ctaStrategy.ctaBacktesting import BacktestingEngine, OptimizationSetting, MINUTE_DB_NAME\n",
"from vnpy.trader.app.ctaStrategy.strategy.strategyAtrRsi import AtrRsiStrategy\n",
"from vnpy.trader.app.ctaStrategy.strategy.strategyMultiTimeframe import MultiTimeframeStrategy"
"#from vnpy.trader.app.ctaStrategy.strategy.strategyAtrRsi import AtrRsiStrategy\n",
"#from vnpy.trader.app.ctaStrategy.strategy.strategyMultiTimeframe import MultiTimeframeStrategy\n",
"from vnpy.trader.app.ctaStrategy.strategy.strategyMultiSignal import MultiSignalStrategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"collapsed": true
},
@ -29,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {
"collapsed": false
},
@ -38,12 +39,12 @@
"# 设置回测使用的数据\n",
"engine.setBacktestingMode(engine.BAR_MODE) # 设置引擎的回测模式为K线\n",
"engine.setDatabase(MINUTE_DB_NAME, 'IF0000') # 设置使用的历史数据库\n",
"engine.setStartDate('20120101') # 设置回测用的数据起始日期"
"engine.setStartDate('20100101') # 设置回测用的数据起始日期"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {
"collapsed": true
},
@ -59,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {
"collapsed": false
},
@ -68,16 +69,31 @@
"# 在引擎中创建策略对象\n",
"d = {'atrLength': 11} # 策略参数配置\n",
"#engine.initStrategy(AtrRsiStrategy, d) # 创建策略对象\n",
"engine.initStrategy(MultiTimeframeStrategy, d) # 创建策略对象"
"#ngine.initStrategy(MultiTimeframeStrategy, d) \n",
"engine.initStrategy(MultiSignalStrategy, {}) "
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2017-12-15 23:01:54.728000\t开始载入数据\n",
"2017-12-15 23:01:54.765000\t载入完成数据量0\n",
"2017-12-15 23:01:54.765000\t开始回测\n",
"2017-12-15 23:01:54.765000\t策略初始化完成\n",
"2017-12-15 23:01:54.765000\t策略启动完成\n",
"2017-12-15 23:01:54.765000\t开始回放数据\n",
"2017-12-15 23:01:54.766000\t数据回放结束\n"
]
}
],
"source": [
"# 运行回测\n",
"engine.runBacktesting() # 运行回测"

File diff suppressed because one or more lines are too long

View File

@ -601,3 +601,41 @@ class ArrayManager(object):
if array:
return up, down
return up[-1], down[-1]
########################################################################
class CtaSignal(object):
"""
CTA策略信号负责纯粹的信号生成目标仓位不参与具体交易管理
"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.signalPos = 0 # 信号仓位
#----------------------------------------------------------------------
def onBar(self, bar):
"""K线推送"""
pass
#----------------------------------------------------------------------
def onTick(self, tick):
"""Tick推送"""
pass
#----------------------------------------------------------------------
def setSignalPos(self, pos):
"""设置信号仓位"""
self.signalPos = pos
#----------------------------------------------------------------------
def getSignalPos(self):
"""获取信号仓位"""
return self.signalPos

View File

@ -0,0 +1,251 @@
# encoding: UTF-8
"""
一个多信号组合策略基于的信号包括
RSI1分钟大于70为多头低于30为空头
CCI1分钟大于10为多头低于-10为空头
MA5分钟快速大于慢速为多头低于慢速为空头
"""
from vnpy.trader.vtObject import VtBarData
from vnpy.trader.vtConstant import EMPTY_STRING
from vnpy.trader.app.ctaStrategy.ctaTemplate import (TargetPosTemplate,
CtaSignal,
BarGenerator,
ArrayManager)
########################################################################
class RsiSignal(CtaSignal):
"""RSI信号"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
super(RsiSignal, self).__init__()
self.rsiWindow = 14
self.rsiLevel = 20
self.rsiLong = 50 + self.rsiLevel
self.rsiShort = 50 - self.rsiLevel
self.bg = BarGenerator(self.onBar)
self.am = ArrayManager()
#----------------------------------------------------------------------
def onTick(self, tick):
"""Tick更新"""
self.bg.updateTick(tick)
#----------------------------------------------------------------------
def onBar(self, bar):
"""K线更新"""
self.am.updateBar(bar)
if not self.am.inited:
self.setSignalPos(0)
rsiValue = self.am.rsi(self.rsiWindow)
if rsiValue >= self.rsiLong:
self.setSignalPos(1)
elif rsiValue <= self.rsiShort:
self.setSignalPos(-1)
else:
self.setSignalPos(0)
########################################################################
class CciSignal(CtaSignal):
"""CCI信号"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
super(CciSignal, self).__init__()
self.cciWindow = 30
self.cciLevel = 10
self.cciLong = self.cciLevel
self.cciShort = -self.cciLevel
self.bg = BarGenerator(self.onBar)
self.am = ArrayManager()
#----------------------------------------------------------------------
def onTick(self, tick):
"""Tick更新"""
self.bg.updateTick(tick)
#----------------------------------------------------------------------
def onBar(self, bar):
"""K线更新"""
self.am.updateBar(bar)
if not self.am.inited:
self.setSignalPos(0)
cciValue = self.am.cci(self.cciWindow)
if cciValue >= self.cciLong:
self.setSignalPos(1)
elif cciValue<= self.cciShort:
self.setSignalPos(-1)
else:
self.setSignalPos(0)
########################################################################
class MaSignal(CtaSignal):
"""双均线信号"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
super(MaSignal, self).__init__()
self.fastWindow = 5
self.slowWindow = 20
self.bg = BarGenerator(self.onBar, 5, self.onFiveBar)
self.am = ArrayManager()
#----------------------------------------------------------------------
def onTick(self, tick):
"""Tick更新"""
self.bg.updateTick(tick)
#----------------------------------------------------------------------
def onBar(self, bar):
"""K线更新"""
self.bg.updateBar(bar)
#----------------------------------------------------------------------
def onFiveBar(self, bar):
"""5分钟K线更新"""
self.am.updateBar(bar)
if not self.am.inited:
self.setSignalPos(0)
fastMa = self.am.sma(self.fastWindow)
slowMa = self.am.sma(self.slowWindow)
if fastMa > slowMa:
self.setSignalPos(1)
elif fastMa < slowMa:
self.setSignalPos(-1)
else:
self.setSignalPos(0)
########################################################################
class MultiSignalStrategy(TargetPosTemplate):
"""跨时间周期交易策略"""
className = 'MultiSignalStrategy'
author = u'用Python的交易员'
# 策略参数
initDays = 10 # 初始化数据所用的天数
fixedSize = 1 # 每次交易的数量
# 策略变量
signalPos = {} # 信号仓位
# 参数列表,保存了参数的名称
paramList = ['name',
'className',
'author',
'vtSymbol']
# 变量列表,保存了变量的名称
varList = ['inited',
'trading',
'pos',
'signalPos',
'targetPos']
# 同步列表,保存了需要保存到数据库的变量名称
syncList = ['pos']
#----------------------------------------------------------------------
def __init__(self, ctaEngine, setting):
"""Constructor"""
super(MultiSignalStrategy, self).__init__(ctaEngine, setting)
self.rsiSignal = RsiSignal()
self.cciSignal = CciSignal()
self.maSignal = MaSignal()
self.signalPos = {
"rsi": 0,
"cci": 0,
"ma": 0
}
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略初始化' %self.name)
# 载入历史数据,并采用回放计算的方式初始化策略数值
initData = self.loadBar(self.initDays)
for bar in initData:
self.onBar(bar)
self.putEvent()
#----------------------------------------------------------------------
def onStart(self):
"""启动策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略启动' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onStop(self):
"""停止策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略停止' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onTick(self, tick):
"""收到行情TICK推送必须由用户继承实现"""
super(MultiSignalStrategy, self).onTick(tick)
self.rsiSignal.onTick(tick)
self.cciSignal.onTick(tick)
self.maSignal.onTick(tick)
#----------------------------------------------------------------------
def onBar(self, bar):
"""收到Bar推送必须由用户继承实现"""
super(MultiSignalStrategy, self).onBar(bar)
self.rsiSignal.onBar(bar)
self.cciSignal.onBar(bar)
self.maSignal.onBar(bar)
self.signalPos['rsi'] = self.rsiSignal.getSignalPos()
self.signalPos['cci'] = self.cciSignal.getSignalPos()
self.signalPos['ma'] = self.maSignal.getSignalPos()
targetPos = 0
for v in self.signalPos.values():
targetPos += v
self.setTargetPos(targetPos)
#----------------------------------------------------------------------
def onOrder(self, order):
"""收到委托变化推送(必须由用户继承实现)"""
super(MultiSignalStrategy, self).onOrder(order)
#----------------------------------------------------------------------
def onTrade(self, trade):
# 发出状态更新事件
self.putEvent()
#----------------------------------------------------------------------
def onStopOrder(self, so):
"""停止单推送"""
pass