vnpy/vn.trader/ctaAlgo/strategyAtrRsi.py
2016-07-02 11:12:56 +08:00

270 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: UTF-8
"""
一个ATR-RSI指标结合的交易策略适合用在股指的1分钟和5分钟线上。
注意事项:
1. 作者不对交易盈利做任何保证,策略代码仅供参考
2. 本策略需要用到talib没有安装的用户请先参考www.vnpy.org上的教程安装
3. 将IF0000_1min.csv用ctaHistoryData.py导入MongoDB后直接运行本文件即可回测策略
"""
from ctaBase import *
from ctaTemplate import CtaTemplate
import talib
import numpy as np
########################################################################
class AtrRsiStrategy(CtaTemplate):
"""结合ATR和RSI指标的一个分钟线交易策略"""
className = 'AtrRsiStrategy'
author = u'用Python的交易员'
# 策略参数
atrLength = 22 # 计算ATR指标的窗口数
atrMaLength = 10 # 计算ATR均线的窗口数
rsiLength = 5 # 计算RSI的窗口数
rsiEntry = 16 # RSI的开仓信号
trailingPercent = 0.8 # 百分比移动止损
initDays = 10 # 初始化数据所用的天数
# 策略变量
bar = None # K线对象
barMinute = EMPTY_STRING # K线当前的分钟
bufferSize = 100 # 需要缓存的数据的大小
bufferCount = 0 # 目前已经缓存了的数据的计数
highArray = np.zeros(bufferSize) # K线最高价的数组
lowArray = np.zeros(bufferSize) # K线最低价的数组
closeArray = np.zeros(bufferSize) # K线收盘价的数组
atrCount = 0 # 目前已经缓存了的ATR的计数
atrArray = np.zeros(bufferSize) # ATR指标的数组
atrValue = 0 # 最新的ATR指标数值
atrMa = 0 # ATR移动平均的数值
rsiValue = 0 # RSI指标的数值
rsiBuy = 0 # RSI买开阈值
rsiSell = 0 # RSI卖开阈值
intraTradeHigh = 0 # 移动止损用的持仓期内最高价
intraTradeLow = 0 # 移动止损用的持仓期内最低价
orderList = [] # 保存委托代码的列表
# 参数列表,保存了参数的名称
paramList = ['name',
'className',
'author',
'vtSymbol',
'atrLength',
'atrMaLength',
'rsiLength',
'rsiEntry',
'trailingPercent']
# 变量列表,保存了变量的名称
varList = ['inited',
'trading',
'pos',
'atrValue',
'atrMa',
'rsiValue',
'rsiBuy',
'rsiSell']
#----------------------------------------------------------------------
def __init__(self, ctaEngine, setting):
"""Constructor"""
super(AtrRsiStrategy, self).__init__(ctaEngine, setting)
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略初始化' %self.name)
# 初始化RSI入场阈值
self.rsiBuy = 50 + self.rsiEntry
self.rsiSell = 50 - self.rsiEntry
# 载入历史数据,并采用回放计算的方式初始化策略数值
initData = self.loadBar(self.initDays)
for bar in initData:
self.onBar(bar)
self.putEvent()
#----------------------------------------------------------------------
def onStart(self):
"""启动策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略启动' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onStop(self):
"""停止策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略停止' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onTick(self, tick):
"""收到行情TICK推送必须由用户继承实现"""
# 计算K线
tickMinute = tick.datetime.minute
if tickMinute != self.barMinute:
if self.bar:
self.onBar(self.bar)
bar = CtaBarData()
bar.vtSymbol = tick.vtSymbol
bar.symbol = tick.symbol
bar.exchange = tick.exchange
bar.open = tick.lastPrice
bar.high = tick.lastPrice
bar.low = tick.lastPrice
bar.close = tick.lastPrice
bar.date = tick.date
bar.time = tick.time
bar.datetime = tick.datetime # K线的时间设为第一个Tick的时间
self.bar = bar # 这种写法为了减少一层访问,加快速度
self.barMinute = tickMinute # 更新当前的分钟
else: # 否则继续累加新的K线
bar = self.bar # 写法同样为了加快速度
bar.high = max(bar.high, tick.lastPrice)
bar.low = min(bar.low, tick.lastPrice)
bar.close = tick.lastPrice
#----------------------------------------------------------------------
def onBar(self, bar):
"""收到Bar推送必须由用户继承实现"""
# 撤销之前发出的尚未成交的委托(包括限价单和停止单)
for orderID in self.orderList:
self.cancelOrder(orderID)
self.orderList = []
# 保存K线数据
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
self.closeArray[-1] = bar.close
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.bufferCount += 1
if self.bufferCount < self.bufferSize:
return
# 计算指标数值
self.atrValue = talib.ATR(self.highArray,
self.lowArray,
self.closeArray,
self.atrLength)[-1]
self.atrArray[0:self.bufferSize-1] = self.atrArray[1:self.bufferSize]
self.atrArray[-1] = self.atrValue
self.atrCount += 1
if self.atrCount < self.bufferSize:
return
self.atrMa = talib.MA(self.atrArray,
self.atrMaLength)[-1]
self.rsiValue = talib.RSI(self.closeArray,
self.rsiLength)[-1]
# 判断是否要进行交易
# 当前无仓位
if self.pos == 0:
self.intraTradeHigh = bar.high
self.intraTradeLow = bar.low
# ATR数值上穿其移动平均线说明行情短期内波动加大
# 即处于趋势的概率较大适合CTA开仓
if self.atrValue > self.atrMa:
# 使用RSI指标的趋势行情时会在超买超卖区钝化特征作为开仓信号
if self.rsiValue > self.rsiBuy:
# 这里为了保证成交选择超价5个整指数点下单
self.buy(bar.close+5, 1)
return
if self.rsiValue < self.rsiSell:
self.short(bar.close-5, 1)
return
# 持有多头仓位
if self.pos == 1:
# 计算多头持有期内的最高价,以及重置最低价
self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
self.intraTradeLow = bar.low
# 计算多头移动止损
longStop = self.intraTradeHigh * (1-self.trailingPercent/100)
# 发出本地止损委托,并且把委托号记录下来,用于后续撤单
orderID = self.sell(longStop, 1, stop=True)
self.orderList.append(orderID)
return
# 持有空头仓位
if self.pos == -1:
self.intraTradeLow = min(self.intraTradeLow, bar.low)
self.intraTradeHigh = bar.high
shortStop = self.intraTradeLow * (1+self.trailingPercent/100)
orderID = self.cover(shortStop, 1, stop=True)
self.orderList.append(orderID)
return
# 发出状态更新事件
self.putEvent()
#----------------------------------------------------------------------
def onOrder(self, order):
"""收到委托变化推送(必须由用户继承实现)"""
pass
#----------------------------------------------------------------------
def onTrade(self, trade):
pass
if __name__ == '__main__':
# 提供直接双击回测的功能
# 导入PyQt4的包是为了保证matplotlib使用PyQt4而不是PySide防止初始化出错
from ctaBacktesting import *
from PyQt4 import QtCore, QtGui
# 创建回测引擎
engine = BacktestingEngine()
# 设置引擎的回测模式为K线
engine.setBacktestingMode(engine.BAR_MODE)
# 设置回测用的数据起始日期
engine.setStartDate('20120101')
# 载入历史数据到引擎中
engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 在引擎中创建策略对象
engine.initStrategy(AtrRsiStrategy, {})
# 开始跑回测
engine.runBacktesting()
# 显示回测结果
engine.showBacktestingResult()