vnpy/vn.trader/ctaAlgo/strategyAtrRsi.py

290 lines
11 KiB
Python
Raw Normal View History

2016-05-23 11:20:27 +00:00
# encoding: UTF-8
"""
一个ATR-RSI指标结合的交易策略适合用在股指的1分钟和5分钟线上
注意事项
1. 作者不对交易盈利做任何保证策略代码仅供参考
2. 本策略需要用到talib没有安装的用户请先参考www.vnpy.org上的教程安装
3. 将IF0000_1min.csv用ctaHistoryData.py导入MongoDB后直接运行本文件即可回测策略
"""
from ctaBase import *
from ctaTemplate import CtaTemplate
import talib
import numpy as np
########################################################################
class AtrRsiStrategy(CtaTemplate):
"""结合ATR和RSI指标的一个分钟线交易策略"""
className = 'AtrRsiStrategy'
author = u'用Python的交易员'
# 策略参数
atrLength = 22 # 计算ATR指标的窗口数
atrMaLength = 10 # 计算ATR均线的窗口数
rsiLength = 5 # 计算RSI的窗口数
rsiEntry = 16 # RSI的开仓信号
trailingPercent = 0.8 # 百分比移动止损
initDays = 10 # 初始化数据所用的天数
# 策略变量
bar = None # K线对象
barMinute = EMPTY_STRING # K线当前的分钟
bufferSize = 100 # 需要缓存的数据的大小
bufferCount = 0 # 目前已经缓存了的数据的计数
highArray = np.zeros(bufferSize) # K线最高价的数组
lowArray = np.zeros(bufferSize) # K线最低价的数组
closeArray = np.zeros(bufferSize) # K线收盘价的数组
atrCount = 0 # 目前已经缓存了的ATR的计数
atrArray = np.zeros(bufferSize) # ATR指标的数组
atrValue = 0 # 最新的ATR指标数值
atrMa = 0 # ATR移动平均的数值
rsiValue = 0 # RSI指标的数值
rsiBuy = 0 # RSI买开阈值
rsiSell = 0 # RSI卖开阈值
intraTradeHigh = 0 # 移动止损用的持仓期内最高价
intraTradeLow = 0 # 移动止损用的持仓期内最低价
orderList = [] # 保存委托代码的列表
# 参数列表,保存了参数的名称
paramList = ['name',
'className',
'author',
'vtSymbol',
'atrLength',
'atrMaLength',
'rsiLength',
'rsiEntry',
'trailingPercent']
# 变量列表,保存了变量的名称
varList = ['inited',
'trading',
'pos',
'atrValue',
'atrMa',
'rsiValue',
'rsiBuy',
'rsiSell']
#----------------------------------------------------------------------
def __init__(self, ctaEngine, setting):
"""Constructor"""
super(AtrRsiStrategy, self).__init__(ctaEngine, setting)
# 注意策略类中的可变对象属性通常是list和dict等在策略初始化时需要重新创建
# 否则会出现多个策略实例之间数据共享的情况,有可能导致潜在的策略逻辑错误风险,
# 策略类中的这些可变对象属性可以选择不写全都放在__init__下面写主要是为了阅读
# 策略时方便(更多是个编程习惯的选择)
2016-05-23 11:20:27 +00:00
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略初始化' %self.name)
# 初始化RSI入场阈值
self.rsiBuy = 50 + self.rsiEntry
self.rsiSell = 50 - self.rsiEntry
# 载入历史数据,并采用回放计算的方式初始化策略数值
initData = self.loadBar(self.initDays)
for bar in initData:
self.onBar(bar)
self.putEvent()
#----------------------------------------------------------------------
def onStart(self):
"""启动策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略启动' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onStop(self):
"""停止策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略停止' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onTick(self, tick):
"""收到行情TICK推送必须由用户继承实现"""
# 计算K线
tickMinute = tick.datetime.minute
if tickMinute != self.barMinute:
if self.bar:
self.onBar(self.bar)
bar = CtaBarData()
bar.vtSymbol = tick.vtSymbol
bar.symbol = tick.symbol
bar.exchange = tick.exchange
bar.open = tick.lastPrice
bar.high = tick.lastPrice
bar.low = tick.lastPrice
bar.close = tick.lastPrice
bar.date = tick.date
bar.time = tick.time
bar.datetime = tick.datetime # K线的时间设为第一个Tick的时间
self.bar = bar # 这种写法为了减少一层访问,加快速度
self.barMinute = tickMinute # 更新当前的分钟
else: # 否则继续累加新的K线
bar = self.bar # 写法同样为了加快速度
bar.high = max(bar.high, tick.lastPrice)
bar.low = min(bar.low, tick.lastPrice)
bar.close = tick.lastPrice
#----------------------------------------------------------------------
def onBar(self, bar):
"""收到Bar推送必须由用户继承实现"""
# 撤销之前发出的尚未成交的委托(包括限价单和停止单)
for orderID in self.orderList:
self.cancelOrder(orderID)
self.orderList = []
# 保存K线数据
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
self.closeArray[-1] = bar.close
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.bufferCount += 1
if self.bufferCount < self.bufferSize:
return
# 计算指标数值
self.atrValue = talib.ATR(self.highArray,
self.lowArray,
self.closeArray,
self.atrLength)[-1]
self.atrArray[0:self.bufferSize-1] = self.atrArray[1:self.bufferSize]
self.atrArray[-1] = self.atrValue
self.atrCount += 1
if self.atrCount < self.bufferSize:
return
self.atrMa = talib.MA(self.atrArray,
self.atrMaLength)[-1]
self.rsiValue = talib.RSI(self.closeArray,
self.rsiLength)[-1]
# 判断是否要进行交易
# 当前无仓位
if self.pos == 0:
self.intraTradeHigh = bar.high
self.intraTradeLow = bar.low
# ATR数值上穿其移动平均线说明行情短期内波动加大
# 即处于趋势的概率较大适合CTA开仓
if self.atrValue > self.atrMa:
# 使用RSI指标的趋势行情时会在超买超卖区钝化特征作为开仓信号
if self.rsiValue > self.rsiBuy:
# 这里为了保证成交选择超价5个整指数点下单
self.buy(bar.close+5, 1)
elif self.rsiValue < self.rsiSell:
2016-05-23 11:20:27 +00:00
self.short(bar.close-5, 1)
# 持有多头仓位
elif self.pos > 0:
2016-05-23 11:20:27 +00:00
# 计算多头持有期内的最高价,以及重置最低价
self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
self.intraTradeLow = bar.low
# 计算多头移动止损
longStop = self.intraTradeHigh * (1-self.trailingPercent/100)
# 发出本地止损委托,并且把委托号记录下来,用于后续撤单
orderID = self.sell(longStop, 1, stop=True)
self.orderList.append(orderID)
# 持有空头仓位
elif self.pos < 0:
2016-05-23 11:20:27 +00:00
self.intraTradeLow = min(self.intraTradeLow, bar.low)
self.intraTradeHigh = bar.high
shortStop = self.intraTradeLow * (1+self.trailingPercent/100)
orderID = self.cover(shortStop, 1, stop=True)
self.orderList.append(orderID)
# 发出状态更新事件
self.putEvent()
#----------------------------------------------------------------------
def onOrder(self, order):
"""收到委托变化推送(必须由用户继承实现)"""
pass
#----------------------------------------------------------------------
def onTrade(self, trade):
# 发出状态更新事件
self.putEvent()
2016-05-23 11:20:27 +00:00
if __name__ == '__main__':
# 提供直接双击回测的功能
# 导入PyQt4的包是为了保证matplotlib使用PyQt4而不是PySide防止初始化出错
from ctaBacktesting import *
from PyQt4 import QtCore, QtGui
# 创建回测引擎
engine = BacktestingEngine()
# 设置引擎的回测模式为K线
engine.setBacktestingMode(engine.BAR_MODE)
# 设置回测用的数据起始日期
engine.setStartDate('20120101')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 设置使用的历史数据库
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
## 在引擎中创建策略对象
#d = {'atrLength': 11}
#engine.initStrategy(AtrRsiStrategy, d)
2016-05-23 11:20:27 +00:00
## 开始跑回测
##engine.runBacktesting()
2016-05-23 11:20:27 +00:00
## 显示回测结果
##engine.showBacktestingResult()
2016-05-23 11:20:27 +00:00
# 跑优化
setting = OptimizationSetting() # 新建一个优化任务设置对象
setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利
setting.addParameter('atrLength', 12, 20, 2) # 增加第一个优化参数atrLength起始11结束12步进1
setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa起始20结束30步进1
setting.addParameter('rsiLength', 5) # 增加一个固定数值的参数
2016-05-23 11:20:27 +00:00
# 性能测试环境I7-3770主频3.4G, 8核心内存16GWindows 7 专业版
# 测试时还跑着一堆其他的程序,性能仅供参考
import time
start = time.time()
# 运行单进程优化函数自动输出结果耗时359秒
engine.runOptimization(AtrRsiStrategy, setting)
# 多进程优化耗时89秒
#engine.runParallelOptimization(AtrRsiStrategy, setting)
print u'耗时:%s' %(time.time()-start)