vnpy/vn.trader/ctaStrategy/strategy/strategyAtrRsi.py
chenxy123 9f6ea1fd5c 1. ctaAlgo模块改名ctaStrategy
2. 将vn.trader下的所有接口统一放到了一个单独的gateway文件夹中,并实现自动识别和加载
2017-03-18 12:19:51 +08:00

291 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: UTF-8
"""
一个ATR-RSI指标结合的交易策略适合用在股指的1分钟和5分钟线上。
注意事项:
1. 作者不对交易盈利做任何保证,策略代码仅供参考
2. 本策略需要用到talib没有安装的用户请先参考www.vnpy.org上的教程安装
3. 将IF0000_1min.csv用ctaHistoryData.py导入MongoDB后直接运行本文件即可回测策略
"""
from ctaBase import *
from ctaTemplate import CtaTemplate
import talib
import numpy as np
########################################################################
class AtrRsiStrategy(CtaTemplate):
"""结合ATR和RSI指标的一个分钟线交易策略"""
className = 'AtrRsiStrategy'
author = u'用Python的交易员'
# 策略参数
atrLength = 22 # 计算ATR指标的窗口数
atrMaLength = 10 # 计算ATR均线的窗口数
rsiLength = 5 # 计算RSI的窗口数
rsiEntry = 16 # RSI的开仓信号
trailingPercent = 0.8 # 百分比移动止损
initDays = 10 # 初始化数据所用的天数
fixedSize = 1 # 每次交易的数量
# 策略变量
bar = None # K线对象
barMinute = EMPTY_STRING # K线当前的分钟
bufferSize = 100 # 需要缓存的数据的大小
bufferCount = 0 # 目前已经缓存了的数据的计数
highArray = np.zeros(bufferSize) # K线最高价的数组
lowArray = np.zeros(bufferSize) # K线最低价的数组
closeArray = np.zeros(bufferSize) # K线收盘价的数组
atrCount = 0 # 目前已经缓存了的ATR的计数
atrArray = np.zeros(bufferSize) # ATR指标的数组
atrValue = 0 # 最新的ATR指标数值
atrMa = 0 # ATR移动平均的数值
rsiValue = 0 # RSI指标的数值
rsiBuy = 0 # RSI买开阈值
rsiSell = 0 # RSI卖开阈值
intraTradeHigh = 0 # 移动止损用的持仓期内最高价
intraTradeLow = 0 # 移动止损用的持仓期内最低价
orderList = [] # 保存委托代码的列表
# 参数列表,保存了参数的名称
paramList = ['name',
'className',
'author',
'vtSymbol',
'atrLength',
'atrMaLength',
'rsiLength',
'rsiEntry',
'trailingPercent']
# 变量列表,保存了变量的名称
varList = ['inited',
'trading',
'pos',
'atrValue',
'atrMa',
'rsiValue',
'rsiBuy',
'rsiSell']
#----------------------------------------------------------------------
def __init__(self, ctaEngine, setting):
"""Constructor"""
super(AtrRsiStrategy, self).__init__(ctaEngine, setting)
# 注意策略类中的可变对象属性通常是list和dict等在策略初始化时需要重新创建
# 否则会出现多个策略实例之间数据共享的情况,有可能导致潜在的策略逻辑错误风险,
# 策略类中的这些可变对象属性可以选择不写全都放在__init__下面写主要是为了阅读
# 策略时方便(更多是个编程习惯的选择)
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略初始化' %self.name)
# 初始化RSI入场阈值
self.rsiBuy = 50 + self.rsiEntry
self.rsiSell = 50 - self.rsiEntry
# 载入历史数据,并采用回放计算的方式初始化策略数值
initData = self.loadBar(self.initDays)
for bar in initData:
self.onBar(bar)
self.putEvent()
#----------------------------------------------------------------------
def onStart(self):
"""启动策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略启动' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onStop(self):
"""停止策略(必须由用户继承实现)"""
self.writeCtaLog(u'%s策略停止' %self.name)
self.putEvent()
#----------------------------------------------------------------------
def onTick(self, tick):
"""收到行情TICK推送必须由用户继承实现"""
# 计算K线
tickMinute = tick.datetime.minute
if tickMinute != self.barMinute:
if self.bar:
self.onBar(self.bar)
bar = CtaBarData()
bar.vtSymbol = tick.vtSymbol
bar.symbol = tick.symbol
bar.exchange = tick.exchange
bar.open = tick.lastPrice
bar.high = tick.lastPrice
bar.low = tick.lastPrice
bar.close = tick.lastPrice
bar.date = tick.date
bar.time = tick.time
bar.datetime = tick.datetime # K线的时间设为第一个Tick的时间
self.bar = bar # 这种写法为了减少一层访问,加快速度
self.barMinute = tickMinute # 更新当前的分钟
else: # 否则继续累加新的K线
bar = self.bar # 写法同样为了加快速度
bar.high = max(bar.high, tick.lastPrice)
bar.low = min(bar.low, tick.lastPrice)
bar.close = tick.lastPrice
#----------------------------------------------------------------------
def onBar(self, bar):
"""收到Bar推送必须由用户继承实现"""
# 撤销之前发出的尚未成交的委托(包括限价单和停止单)
for orderID in self.orderList:
self.cancelOrder(orderID)
self.orderList = []
# 保存K线数据
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
self.closeArray[-1] = bar.close
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.bufferCount += 1
if self.bufferCount < self.bufferSize:
return
# 计算指标数值
self.atrValue = talib.ATR(self.highArray,
self.lowArray,
self.closeArray,
self.atrLength)[-1]
self.atrArray[0:self.bufferSize-1] = self.atrArray[1:self.bufferSize]
self.atrArray[-1] = self.atrValue
self.atrCount += 1
if self.atrCount < self.bufferSize:
return
self.atrMa = talib.MA(self.atrArray,
self.atrMaLength)[-1]
self.rsiValue = talib.RSI(self.closeArray,
self.rsiLength)[-1]
# 判断是否要进行交易
# 当前无仓位
if self.pos == 0:
self.intraTradeHigh = bar.high
self.intraTradeLow = bar.low
# ATR数值上穿其移动平均线说明行情短期内波动加大
# 即处于趋势的概率较大适合CTA开仓
if self.atrValue > self.atrMa:
# 使用RSI指标的趋势行情时会在超买超卖区钝化特征作为开仓信号
if self.rsiValue > self.rsiBuy:
# 这里为了保证成交选择超价5个整指数点下单
self.buy(bar.close+5, self.fixedSize)
elif self.rsiValue < self.rsiSell:
self.short(bar.close-5, self.fixedSize)
# 持有多头仓位
elif self.pos > 0:
# 计算多头持有期内的最高价,以及重置最低价
self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
self.intraTradeLow = bar.low
# 计算多头移动止损
longStop = self.intraTradeHigh * (1-self.trailingPercent/100)
# 发出本地止损委托,并且把委托号记录下来,用于后续撤单
orderID = self.sell(longStop, abs(self.pos), stop=True)
self.orderList.append(orderID)
# 持有空头仓位
elif self.pos < 0:
self.intraTradeLow = min(self.intraTradeLow, bar.low)
self.intraTradeHigh = bar.high
shortStop = self.intraTradeLow * (1+self.trailingPercent/100)
orderID = self.cover(shortStop, abs(self.pos), stop=True)
self.orderList.append(orderID)
# 发出状态更新事件
self.putEvent()
#----------------------------------------------------------------------
def onOrder(self, order):
"""收到委托变化推送(必须由用户继承实现)"""
pass
#----------------------------------------------------------------------
def onTrade(self, trade):
# 发出状态更新事件
self.putEvent()
if __name__ == '__main__':
# 提供直接双击回测的功能
# 导入PyQt4的包是为了保证matplotlib使用PyQt4而不是PySide防止初始化出错
from ctaBacktesting import *
from PyQt4 import QtCore, QtGui
# 创建回测引擎
engine = BacktestingEngine()
# 设置引擎的回测模式为K线
engine.setBacktestingMode(engine.BAR_MODE)
# 设置回测用的数据起始日期
engine.setStartDate('20120101')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 设置使用的历史数据库
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 在引擎中创建策略对象
d = {'atrLength': 11}
engine.initStrategy(AtrRsiStrategy, d)
# 开始跑回测
engine.runBacktesting()
# 显示回测结果
engine.showBacktestingResult()
## 跑优化
#setting = OptimizationSetting() # 新建一个优化任务设置对象
#setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利
#setting.addParameter('atrLength', 12, 20, 2) # 增加第一个优化参数atrLength起始11结束12步进1
#setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa起始20结束30步进1
#setting.addParameter('rsiLength', 5) # 增加一个固定数值的参数
## 性能测试环境I7-3770主频3.4G, 8核心内存16GWindows 7 专业版
## 测试时还跑着一堆其他的程序,性能仅供参考
#import time
#start = time.time()
## 运行单进程优化函数自动输出结果耗时359秒
#engine.runOptimization(AtrRsiStrategy, setting)
## 多进程优化耗时89秒
##engine.runParallelOptimization(AtrRsiStrategy, setting)
#print u'耗时:%s' %(time.time()-start)