[Add]增加K线序列管理器,用于简化技术指标相关数据的计算写法
This commit is contained in:
parent
93ee7c35e9
commit
b901e579ae
File diff suppressed because one or more lines are too long
@ -4,6 +4,9 @@
|
||||
本文件包含了CTA引擎中的策略开发用模板,开发策略时需要继承CtaTemplate类。
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
|
||||
from vnpy.trader.vtConstant import *
|
||||
from vnpy.trader.vtObject import VtBarData
|
||||
|
||||
@ -410,3 +413,107 @@ class BarManager(object):
|
||||
|
||||
# 清空老K线缓存对象
|
||||
self.xminBar = None
|
||||
|
||||
|
||||
########################################################################
|
||||
class ArrayManager(object):
|
||||
"""
|
||||
K线序列管理工具,负责:
|
||||
1. K线时间序列的维护
|
||||
2. 常用技术指标的计算
|
||||
"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, size=100):
|
||||
"""Constructor"""
|
||||
self.count = 0 # 缓存计数
|
||||
self.size = size # 缓存大小
|
||||
self.inited = False # True if count>=size
|
||||
|
||||
self.openArray = np.zeros(size) # OHLC
|
||||
self.highArray = np.zeros(size)
|
||||
self.lowArray = np.zeros(size)
|
||||
self.closeArray = np.zeros(size)
|
||||
self.volumeArray = np.zeros(size)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def updateBar(self, bar):
|
||||
"""更新K线"""
|
||||
self.count += 1
|
||||
if not self.inited and self.count >= self.size:
|
||||
self.inited = True
|
||||
|
||||
self.openArray[0:self.size-1] = self.closeArray[1:self.size]
|
||||
self.highArray[0:self.size-1] = self.highArray[1:self.size]
|
||||
self.lowArray[0:self.size-1] = self.lowArray[1:self.size]
|
||||
self.closeArray[0:self.size-1] = self.closeArray[1:self.size]
|
||||
self.volumeArray[0:self.size-1] = self.volumeArray[1:self.size]
|
||||
|
||||
self.openArray[-1] = bar.open
|
||||
self.highArray[-1] = bar.high
|
||||
self.lowArray[-1] = bar.low
|
||||
self.closeArray[-1] = bar.close
|
||||
self.volumeArray[-1] = bar.volume
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def open(self):
|
||||
"""获取开盘价序列"""
|
||||
return self.openArray
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def high(self):
|
||||
"""获取最高价序列"""
|
||||
return self.highArray
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def low(self):
|
||||
"""获取最低价序列"""
|
||||
return self.lowArray
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def close(self):
|
||||
"""获取收盘价序列"""
|
||||
return self.closeArray
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def volume(self):
|
||||
"""获取成交量序列"""
|
||||
return self.volumeArray
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def sma(self, n, shift=0):
|
||||
"""均线"""
|
||||
return talib.SMA(self.close, n)[-1-shift]
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def std(self, n, shift=0):
|
||||
"""标准差"""
|
||||
return talib.STDDEV(self.close, n)[-1-shift]
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def boll(self, n, dev, shift=0):
|
||||
"""布林通道"""
|
||||
mid = self.sma(n, shift)
|
||||
std = self.std(n, shift)
|
||||
|
||||
up = mid + std * dev
|
||||
down = mid - std * dev
|
||||
|
||||
return up, down
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def cci(self, n, shift=0):
|
||||
"""CCI指标"""
|
||||
return talib.CCI(self.high, self.low, self.close, n)[-1-shift]
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def atr(self, n, shift=0):
|
||||
"""ATR指标"""
|
||||
return talib.ATR(self.high, self.low, self.close, n)[-1-shift]
|
||||
|
||||
|
@ -23,7 +23,9 @@ import numpy as np
|
||||
|
||||
from vnpy.trader.vtObject import VtBarData
|
||||
from vnpy.trader.vtConstant import EMPTY_STRING
|
||||
from vnpy.trader.app.ctaStrategy.ctaTemplate import CtaTemplate, BarManager
|
||||
from vnpy.trader.app.ctaStrategy.ctaTemplate import (CtaTemplate,
|
||||
BarManager,
|
||||
ArrayManager)
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -42,14 +44,6 @@ class BollChannelStrategy(CtaTemplate):
|
||||
fixedSize = 1 # 每次交易的数量
|
||||
|
||||
# 策略变量
|
||||
bufferSize = 100 # 需要缓存的数据的大小
|
||||
bufferCount = 0 # 目前已经缓存了的数据的计数
|
||||
highArray = np.zeros(bufferSize) # K线最高价的数组
|
||||
lowArray = np.zeros(bufferSize) # K线最低价的数组
|
||||
closeArray = np.zeros(bufferSize) # K线收盘价的数组
|
||||
|
||||
bollMid = 0 # 布林通道中轨
|
||||
bollStd = 0 # 布林通道标准差
|
||||
bollUp = 0 # 布林通道上轨
|
||||
bollDown = 0 # 布林通道下轨
|
||||
cciValue = 0 # CCI指标数值
|
||||
@ -79,8 +73,6 @@ class BollChannelStrategy(CtaTemplate):
|
||||
varList = ['inited',
|
||||
'trading',
|
||||
'pos',
|
||||
'bollMid',
|
||||
'bollStd',
|
||||
'bollUp',
|
||||
'bollDown',
|
||||
'cciValue',
|
||||
@ -96,6 +88,7 @@ class BollChannelStrategy(CtaTemplate):
|
||||
super(BollChannelStrategy, self).__init__(ctaEngine, setting)
|
||||
|
||||
self.bm = BarManager(self.onBar, 15, self.onXminBar) # 创建K线合成器对象
|
||||
self.am = ArrayManager()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onInit(self):
|
||||
@ -140,36 +133,21 @@ class BollChannelStrategy(CtaTemplate):
|
||||
self.orderList = []
|
||||
|
||||
# 保存K线数据
|
||||
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
|
||||
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
|
||||
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
|
||||
am = self.am
|
||||
|
||||
self.closeArray[-1] = bar.close
|
||||
self.highArray[-1] = bar.high
|
||||
self.lowArray[-1] = bar.low
|
||||
am.updateBar(bar)
|
||||
|
||||
self.bufferCount += 1
|
||||
if self.bufferCount < self.bufferSize:
|
||||
if not am.inited:
|
||||
return
|
||||
|
||||
# 计算指标数值
|
||||
self.bollMid = self.closeArray[-self.bollWindow:].mean()
|
||||
self.bollStd = self.closeArray[-self.bollWindow:].std()
|
||||
self.bollUp = self.bollMid + self.bollStd * self.bollDev
|
||||
self.bollDown = self.bollMid - self.bollStd * self.bollDev
|
||||
|
||||
self.cciValue = talib.CCI(self.highArray,
|
||||
self.lowArray,
|
||||
self.closeArray,
|
||||
self.cciWindow)[-1]
|
||||
self.atrValue = talib.ATR(self.highArray,
|
||||
self.lowArray,
|
||||
self.closeArray,
|
||||
self.atrWindow)[-1]
|
||||
self.bollUp, self.bollDown = am.boll(self.bollWindow, self.bollDev)
|
||||
self.cciValue = am.cci(self.cciWindow)
|
||||
self.atrValue = am.atr(self.atrWindow)
|
||||
|
||||
# 判断是否要进行交易
|
||||
|
||||
# 当前无仓位,发送OCO开仓委托
|
||||
# 当前无仓位,发送开仓委托
|
||||
if self.pos == 0:
|
||||
self.intraTradeHigh = bar.high
|
||||
self.intraTradeLow = bar.low
|
||||
@ -216,3 +194,4 @@ class BollChannelStrategy(CtaTemplate):
|
||||
def onStopOrder(self, so):
|
||||
"""停止单推送"""
|
||||
pass
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user