[Add]增加K线序列管理器,用于简化技术指标相关数据的计算写法

This commit is contained in:
vn.py 2017-10-07 18:25:33 +08:00
parent 93ee7c35e9
commit b901e579ae
3 changed files with 133 additions and 88 deletions

File diff suppressed because one or more lines are too long

View File

@ -4,6 +4,9 @@
本文件包含了CTA引擎中的策略开发用模板开发策略时需要继承CtaTemplate类 本文件包含了CTA引擎中的策略开发用模板开发策略时需要继承CtaTemplate类
''' '''
import numpy as np
import talib
from vnpy.trader.vtConstant import * from vnpy.trader.vtConstant import *
from vnpy.trader.vtObject import VtBarData from vnpy.trader.vtObject import VtBarData
@ -410,3 +413,107 @@ class BarManager(object):
# 清空老K线缓存对象 # 清空老K线缓存对象
self.xminBar = None self.xminBar = None
########################################################################
class ArrayManager(object):
"""
K线序列管理工具负责
1. K线时间序列的维护
2. 常用技术指标的计算
"""
#----------------------------------------------------------------------
def __init__(self, size=100):
"""Constructor"""
self.count = 0 # 缓存计数
self.size = size # 缓存大小
self.inited = False # True if count>=size
self.openArray = np.zeros(size) # OHLC
self.highArray = np.zeros(size)
self.lowArray = np.zeros(size)
self.closeArray = np.zeros(size)
self.volumeArray = np.zeros(size)
#----------------------------------------------------------------------
def updateBar(self, bar):
"""更新K线"""
self.count += 1
if not self.inited and self.count >= self.size:
self.inited = True
self.openArray[0:self.size-1] = self.closeArray[1:self.size]
self.highArray[0:self.size-1] = self.highArray[1:self.size]
self.lowArray[0:self.size-1] = self.lowArray[1:self.size]
self.closeArray[0:self.size-1] = self.closeArray[1:self.size]
self.volumeArray[0:self.size-1] = self.volumeArray[1:self.size]
self.openArray[-1] = bar.open
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.closeArray[-1] = bar.close
self.volumeArray[-1] = bar.volume
#----------------------------------------------------------------------
@property
def open(self):
"""获取开盘价序列"""
return self.openArray
#----------------------------------------------------------------------
@property
def high(self):
"""获取最高价序列"""
return self.highArray
#----------------------------------------------------------------------
@property
def low(self):
"""获取最低价序列"""
return self.lowArray
#----------------------------------------------------------------------
@property
def close(self):
"""获取收盘价序列"""
return self.closeArray
#----------------------------------------------------------------------
@property
def volume(self):
"""获取成交量序列"""
return self.volumeArray
#----------------------------------------------------------------------
def sma(self, n, shift=0):
"""均线"""
return talib.SMA(self.close, n)[-1-shift]
#----------------------------------------------------------------------
def std(self, n, shift=0):
"""标准差"""
return talib.STDDEV(self.close, n)[-1-shift]
#----------------------------------------------------------------------
def boll(self, n, dev, shift=0):
"""布林通道"""
mid = self.sma(n, shift)
std = self.std(n, shift)
up = mid + std * dev
down = mid - std * dev
return up, down
#----------------------------------------------------------------------
def cci(self, n, shift=0):
"""CCI指标"""
return talib.CCI(self.high, self.low, self.close, n)[-1-shift]
#----------------------------------------------------------------------
def atr(self, n, shift=0):
"""ATR指标"""
return talib.ATR(self.high, self.low, self.close, n)[-1-shift]

View File

@ -23,7 +23,9 @@ import numpy as np
from vnpy.trader.vtObject import VtBarData from vnpy.trader.vtObject import VtBarData
from vnpy.trader.vtConstant import EMPTY_STRING from vnpy.trader.vtConstant import EMPTY_STRING
from vnpy.trader.app.ctaStrategy.ctaTemplate import CtaTemplate, BarManager from vnpy.trader.app.ctaStrategy.ctaTemplate import (CtaTemplate,
BarManager,
ArrayManager)
######################################################################## ########################################################################
@ -42,14 +44,6 @@ class BollChannelStrategy(CtaTemplate):
fixedSize = 1 # 每次交易的数量 fixedSize = 1 # 每次交易的数量
# 策略变量 # 策略变量
bufferSize = 100 # 需要缓存的数据的大小
bufferCount = 0 # 目前已经缓存了的数据的计数
highArray = np.zeros(bufferSize) # K线最高价的数组
lowArray = np.zeros(bufferSize) # K线最低价的数组
closeArray = np.zeros(bufferSize) # K线收盘价的数组
bollMid = 0 # 布林通道中轨
bollStd = 0 # 布林通道标准差
bollUp = 0 # 布林通道上轨 bollUp = 0 # 布林通道上轨
bollDown = 0 # 布林通道下轨 bollDown = 0 # 布林通道下轨
cciValue = 0 # CCI指标数值 cciValue = 0 # CCI指标数值
@ -79,8 +73,6 @@ class BollChannelStrategy(CtaTemplate):
varList = ['inited', varList = ['inited',
'trading', 'trading',
'pos', 'pos',
'bollMid',
'bollStd',
'bollUp', 'bollUp',
'bollDown', 'bollDown',
'cciValue', 'cciValue',
@ -96,6 +88,7 @@ class BollChannelStrategy(CtaTemplate):
super(BollChannelStrategy, self).__init__(ctaEngine, setting) super(BollChannelStrategy, self).__init__(ctaEngine, setting)
self.bm = BarManager(self.onBar, 15, self.onXminBar) # 创建K线合成器对象 self.bm = BarManager(self.onBar, 15, self.onXminBar) # 创建K线合成器对象
self.am = ArrayManager()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def onInit(self): def onInit(self):
@ -140,36 +133,21 @@ class BollChannelStrategy(CtaTemplate):
self.orderList = [] self.orderList = []
# 保存K线数据 # 保存K线数据
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize] am = self.am
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
self.closeArray[-1] = bar.close am.updateBar(bar)
self.highArray[-1] = bar.high
self.lowArray[-1] = bar.low
self.bufferCount += 1 if not am.inited:
if self.bufferCount < self.bufferSize:
return return
# 计算指标数值 # 计算指标数值
self.bollMid = self.closeArray[-self.bollWindow:].mean() self.bollUp, self.bollDown = am.boll(self.bollWindow, self.bollDev)
self.bollStd = self.closeArray[-self.bollWindow:].std() self.cciValue = am.cci(self.cciWindow)
self.bollUp = self.bollMid + self.bollStd * self.bollDev self.atrValue = am.atr(self.atrWindow)
self.bollDown = self.bollMid - self.bollStd * self.bollDev
self.cciValue = talib.CCI(self.highArray,
self.lowArray,
self.closeArray,
self.cciWindow)[-1]
self.atrValue = talib.ATR(self.highArray,
self.lowArray,
self.closeArray,
self.atrWindow)[-1]
# 判断是否要进行交易 # 判断是否要进行交易
# 当前无仓位,发送OCO开仓委托 # 当前无仓位,发送开仓委托
if self.pos == 0: if self.pos == 0:
self.intraTradeHigh = bar.high self.intraTradeHigh = bar.high
self.intraTradeLow = bar.low self.intraTradeLow = bar.low
@ -216,3 +194,4 @@ class BollChannelStrategy(CtaTemplate):
def onStopOrder(self, so): def onStopOrder(self, so):
"""停止单推送""" """停止单推送"""
pass pass