[Add]增加K线序列管理器,用于简化技术指标相关数据的计算写法
This commit is contained in:
parent
93ee7c35e9
commit
b901e579ae
File diff suppressed because one or more lines are too long
@ -4,6 +4,9 @@
|
|||||||
本文件包含了CTA引擎中的策略开发用模板,开发策略时需要继承CtaTemplate类。
|
本文件包含了CTA引擎中的策略开发用模板,开发策略时需要继承CtaTemplate类。
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import talib
|
||||||
|
|
||||||
from vnpy.trader.vtConstant import *
|
from vnpy.trader.vtConstant import *
|
||||||
from vnpy.trader.vtObject import VtBarData
|
from vnpy.trader.vtObject import VtBarData
|
||||||
|
|
||||||
@ -410,3 +413,107 @@ class BarManager(object):
|
|||||||
|
|
||||||
# 清空老K线缓存对象
|
# 清空老K线缓存对象
|
||||||
self.xminBar = None
|
self.xminBar = None
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################
|
||||||
|
class ArrayManager(object):
|
||||||
|
"""
|
||||||
|
K线序列管理工具,负责:
|
||||||
|
1. K线时间序列的维护
|
||||||
|
2. 常用技术指标的计算
|
||||||
|
"""
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def __init__(self, size=100):
|
||||||
|
"""Constructor"""
|
||||||
|
self.count = 0 # 缓存计数
|
||||||
|
self.size = size # 缓存大小
|
||||||
|
self.inited = False # True if count>=size
|
||||||
|
|
||||||
|
self.openArray = np.zeros(size) # OHLC
|
||||||
|
self.highArray = np.zeros(size)
|
||||||
|
self.lowArray = np.zeros(size)
|
||||||
|
self.closeArray = np.zeros(size)
|
||||||
|
self.volumeArray = np.zeros(size)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def updateBar(self, bar):
|
||||||
|
"""更新K线"""
|
||||||
|
self.count += 1
|
||||||
|
if not self.inited and self.count >= self.size:
|
||||||
|
self.inited = True
|
||||||
|
|
||||||
|
self.openArray[0:self.size-1] = self.closeArray[1:self.size]
|
||||||
|
self.highArray[0:self.size-1] = self.highArray[1:self.size]
|
||||||
|
self.lowArray[0:self.size-1] = self.lowArray[1:self.size]
|
||||||
|
self.closeArray[0:self.size-1] = self.closeArray[1:self.size]
|
||||||
|
self.volumeArray[0:self.size-1] = self.volumeArray[1:self.size]
|
||||||
|
|
||||||
|
self.openArray[-1] = bar.open
|
||||||
|
self.highArray[-1] = bar.high
|
||||||
|
self.lowArray[-1] = bar.low
|
||||||
|
self.closeArray[-1] = bar.close
|
||||||
|
self.volumeArray[-1] = bar.volume
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def open(self):
|
||||||
|
"""获取开盘价序列"""
|
||||||
|
return self.openArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def high(self):
|
||||||
|
"""获取最高价序列"""
|
||||||
|
return self.highArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def low(self):
|
||||||
|
"""获取最低价序列"""
|
||||||
|
return self.lowArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def close(self):
|
||||||
|
"""获取收盘价序列"""
|
||||||
|
return self.closeArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@property
|
||||||
|
def volume(self):
|
||||||
|
"""获取成交量序列"""
|
||||||
|
return self.volumeArray
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def sma(self, n, shift=0):
|
||||||
|
"""均线"""
|
||||||
|
return talib.SMA(self.close, n)[-1-shift]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def std(self, n, shift=0):
|
||||||
|
"""标准差"""
|
||||||
|
return talib.STDDEV(self.close, n)[-1-shift]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def boll(self, n, dev, shift=0):
|
||||||
|
"""布林通道"""
|
||||||
|
mid = self.sma(n, shift)
|
||||||
|
std = self.std(n, shift)
|
||||||
|
|
||||||
|
up = mid + std * dev
|
||||||
|
down = mid - std * dev
|
||||||
|
|
||||||
|
return up, down
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def cci(self, n, shift=0):
|
||||||
|
"""CCI指标"""
|
||||||
|
return talib.CCI(self.high, self.low, self.close, n)[-1-shift]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def atr(self, n, shift=0):
|
||||||
|
"""ATR指标"""
|
||||||
|
return talib.ATR(self.high, self.low, self.close, n)[-1-shift]
|
||||||
|
|
||||||
|
|
@ -23,7 +23,9 @@ import numpy as np
|
|||||||
|
|
||||||
from vnpy.trader.vtObject import VtBarData
|
from vnpy.trader.vtObject import VtBarData
|
||||||
from vnpy.trader.vtConstant import EMPTY_STRING
|
from vnpy.trader.vtConstant import EMPTY_STRING
|
||||||
from vnpy.trader.app.ctaStrategy.ctaTemplate import CtaTemplate, BarManager
|
from vnpy.trader.app.ctaStrategy.ctaTemplate import (CtaTemplate,
|
||||||
|
BarManager,
|
||||||
|
ArrayManager)
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
@ -42,14 +44,6 @@ class BollChannelStrategy(CtaTemplate):
|
|||||||
fixedSize = 1 # 每次交易的数量
|
fixedSize = 1 # 每次交易的数量
|
||||||
|
|
||||||
# 策略变量
|
# 策略变量
|
||||||
bufferSize = 100 # 需要缓存的数据的大小
|
|
||||||
bufferCount = 0 # 目前已经缓存了的数据的计数
|
|
||||||
highArray = np.zeros(bufferSize) # K线最高价的数组
|
|
||||||
lowArray = np.zeros(bufferSize) # K线最低价的数组
|
|
||||||
closeArray = np.zeros(bufferSize) # K线收盘价的数组
|
|
||||||
|
|
||||||
bollMid = 0 # 布林通道中轨
|
|
||||||
bollStd = 0 # 布林通道标准差
|
|
||||||
bollUp = 0 # 布林通道上轨
|
bollUp = 0 # 布林通道上轨
|
||||||
bollDown = 0 # 布林通道下轨
|
bollDown = 0 # 布林通道下轨
|
||||||
cciValue = 0 # CCI指标数值
|
cciValue = 0 # CCI指标数值
|
||||||
@ -79,8 +73,6 @@ class BollChannelStrategy(CtaTemplate):
|
|||||||
varList = ['inited',
|
varList = ['inited',
|
||||||
'trading',
|
'trading',
|
||||||
'pos',
|
'pos',
|
||||||
'bollMid',
|
|
||||||
'bollStd',
|
|
||||||
'bollUp',
|
'bollUp',
|
||||||
'bollDown',
|
'bollDown',
|
||||||
'cciValue',
|
'cciValue',
|
||||||
@ -96,6 +88,7 @@ class BollChannelStrategy(CtaTemplate):
|
|||||||
super(BollChannelStrategy, self).__init__(ctaEngine, setting)
|
super(BollChannelStrategy, self).__init__(ctaEngine, setting)
|
||||||
|
|
||||||
self.bm = BarManager(self.onBar, 15, self.onXminBar) # 创建K线合成器对象
|
self.bm = BarManager(self.onBar, 15, self.onXminBar) # 创建K线合成器对象
|
||||||
|
self.am = ArrayManager()
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def onInit(self):
|
def onInit(self):
|
||||||
@ -140,36 +133,21 @@ class BollChannelStrategy(CtaTemplate):
|
|||||||
self.orderList = []
|
self.orderList = []
|
||||||
|
|
||||||
# 保存K线数据
|
# 保存K线数据
|
||||||
self.closeArray[0:self.bufferSize-1] = self.closeArray[1:self.bufferSize]
|
am = self.am
|
||||||
self.highArray[0:self.bufferSize-1] = self.highArray[1:self.bufferSize]
|
|
||||||
self.lowArray[0:self.bufferSize-1] = self.lowArray[1:self.bufferSize]
|
|
||||||
|
|
||||||
self.closeArray[-1] = bar.close
|
|
||||||
self.highArray[-1] = bar.high
|
|
||||||
self.lowArray[-1] = bar.low
|
|
||||||
|
|
||||||
self.bufferCount += 1
|
|
||||||
if self.bufferCount < self.bufferSize:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 计算指标数值
|
|
||||||
self.bollMid = self.closeArray[-self.bollWindow:].mean()
|
|
||||||
self.bollStd = self.closeArray[-self.bollWindow:].std()
|
|
||||||
self.bollUp = self.bollMid + self.bollStd * self.bollDev
|
|
||||||
self.bollDown = self.bollMid - self.bollStd * self.bollDev
|
|
||||||
|
|
||||||
self.cciValue = talib.CCI(self.highArray,
|
am.updateBar(bar)
|
||||||
self.lowArray,
|
|
||||||
self.closeArray,
|
if not am.inited:
|
||||||
self.cciWindow)[-1]
|
return
|
||||||
self.atrValue = talib.ATR(self.highArray,
|
|
||||||
self.lowArray,
|
# 计算指标数值
|
||||||
self.closeArray,
|
self.bollUp, self.bollDown = am.boll(self.bollWindow, self.bollDev)
|
||||||
self.atrWindow)[-1]
|
self.cciValue = am.cci(self.cciWindow)
|
||||||
|
self.atrValue = am.atr(self.atrWindow)
|
||||||
|
|
||||||
# 判断是否要进行交易
|
# 判断是否要进行交易
|
||||||
|
|
||||||
# 当前无仓位,发送OCO开仓委托
|
# 当前无仓位,发送开仓委托
|
||||||
if self.pos == 0:
|
if self.pos == 0:
|
||||||
self.intraTradeHigh = bar.high
|
self.intraTradeHigh = bar.high
|
||||||
self.intraTradeLow = bar.low
|
self.intraTradeLow = bar.low
|
||||||
@ -215,4 +193,5 @@ class BollChannelStrategy(CtaTemplate):
|
|||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def onStopOrder(self, so):
|
def onStopOrder(self, so):
|
||||||
"""停止单推送"""
|
"""停止单推送"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user