20151105
This commit is contained in:
parent
43f96b1e05
commit
f28351a242
1
vn.strategy/__init__.py
Normal file
1
vn.strategy/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__author__ = 'Incense'
|
1
vn.strategy/strategydemo/__init__.py
Normal file
1
vn.strategy/strategydemo/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
__author__ = 'Incense'
|
48
vn.strategy/strategydemo/produceBarData.py
Normal file
48
vn.strategy/strategydemo/produceBarData.py
Normal file
@ -0,0 +1,48 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
from strategyEngine import *
|
||||
from backtestingEngine import *
|
||||
from stratetyProduceBar import StrategyProduceBar
|
||||
import decimal
|
||||
|
||||
def main():
|
||||
"""回测程序主函数"""
|
||||
# symbol = 'IF1506'
|
||||
symbol = 'a'
|
||||
|
||||
# 创建回测引擎
|
||||
be = BacktestingEngine()
|
||||
|
||||
# 创建策略引擎对象
|
||||
se = StrategyEngine(be.eventEngine, be, backtesting=True)
|
||||
be.setStrategyEngine(se)
|
||||
|
||||
# 初始化回测引擎
|
||||
# be.connectMongo()
|
||||
be.connectMysql()
|
||||
# be.loadMongoDataHistory(symbol, datetime(2015,5,1), datetime.today())
|
||||
# be.loadMongoDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,14))
|
||||
|
||||
be.setDataHistory(symbol, datetime(2012,1,1), datetime(2012,1,31))
|
||||
|
||||
# 创建策略对象
|
||||
setting = {}
|
||||
#setting['fastAlpha'] = 0.2
|
||||
#setting['slowAlpha'] = 0.05
|
||||
# setting['startDate'] = datetime(year=2015, month=5, day=20)
|
||||
setting['startDate'] = datetime(year=2012, month=1, day=1)
|
||||
|
||||
se.createStrategy(u'生成M1M5数据策略', symbol, StrategyProduceBar, setting)
|
||||
|
||||
# 启动所有策略
|
||||
se.startAll()
|
||||
|
||||
# 开始回测
|
||||
be.startBacktesting()
|
||||
|
||||
|
||||
# 回测脚本
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
@ -24,16 +24,16 @@ import cPickle
|
||||
OFFSET_OPEN = '0' # 开仓
|
||||
OFFSET_CLOSE = '1' # 平仓
|
||||
|
||||
DIRECTION_BUY = '0' # 买入
|
||||
DIRECTION_SELL = '1' # 卖出
|
||||
DIRECTION_BUY = '0' # 买入 多
|
||||
DIRECTION_SELL = '1' # 卖出 空
|
||||
|
||||
PRICETYPE_LIMIT = '2' # 限价
|
||||
|
||||
# buy 买入开仓 : DIRECTION_BUY = '0' OFFSET_OPEN = '0'
|
||||
# sell 卖出平仓 : DIRECTION_SELL = '1' OFFSET_CLOSE = '1'
|
||||
# buy 买入开仓 开多 : DIRECTION_BUY = '0' OFFSET_OPEN = '0'
|
||||
# sell 卖出平仓 平多 : DIRECTION_SELL = '1' OFFSET_CLOSE = '1'
|
||||
|
||||
# short 卖出开仓 : DIRECTION_SELL = '1' OFFSET_OPEN = '0'
|
||||
# cover 买入平仓 : DIRECTION_BUY = '0' OFFSET_CLOSE = '1'
|
||||
# short 卖出开仓 开空 : DIRECTION_SELL = '1' OFFSET_OPEN = '0'
|
||||
# cover 买入平仓 平空 : DIRECTION_BUY = '0' OFFSET_CLOSE = '1'
|
||||
|
||||
########################################################################
|
||||
class Tick:
|
||||
@ -400,6 +400,72 @@ class StrategyEngine(object):
|
||||
else:
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadBarFromMysql(self, symbol, startDate, endDate, barType='M5'):
|
||||
"""从MysqlDB中读取Bar数据
|
||||
startDate,(包含)开始日期
|
||||
endDate, (包含)结束日期
|
||||
barType ='M1' K线类型,1分钟线
|
||||
barType ='M5' K线类型,5分钟线
|
||||
barType = 'D1' K线类型,日线
|
||||
"""
|
||||
|
||||
if self.__mysqlConnected:
|
||||
|
||||
#获取指针
|
||||
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
|
||||
|
||||
if endDate:
|
||||
# 指定开始与结束日期
|
||||
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
|
||||
'where date between cast(\'{2}\' as date) and cast(\'{3}\' as date) ' \
|
||||
'order by datetime'.format(symbol, barType, startDate, endDate)
|
||||
|
||||
elif startDate:
|
||||
# 指定开始日期
|
||||
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
|
||||
'where date > cast(\'{2}\' as date) order by datetime'.format(symbol,
|
||||
barType, startDate)
|
||||
|
||||
else:
|
||||
# 没有指定,所有日期数据
|
||||
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
|
||||
' order by datetime'.format(symbol, barType)
|
||||
|
||||
print sqlstring
|
||||
|
||||
count = cur.execute(sqlstring)
|
||||
|
||||
# cx = cur.fetchall()
|
||||
fetch_counts = 0
|
||||
|
||||
cx = None
|
||||
|
||||
fetch_size = 1000
|
||||
|
||||
while True:
|
||||
results = cur.fetchmany(fetch_size)
|
||||
|
||||
if not results:
|
||||
break
|
||||
|
||||
if fetch_counts == 0:
|
||||
cx = results
|
||||
else:
|
||||
cx = cx + results
|
||||
|
||||
fetch_counts = fetch_counts+fetch_size
|
||||
|
||||
print u'历史{0}Bar数据载入{1}条'.format(barType, fetch_counts)
|
||||
|
||||
self.writeLog(u'历史{0}Bar数据载入完成,{1}~{2},共{3}条'.format(barType,startDate, endDate, count))
|
||||
|
||||
print u'策略引擎:历史{0}Bar数据载入完成,{1}~{2},共{3}条'.format(barType,startDate, endDate, count)
|
||||
|
||||
return cx
|
||||
else:
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
|
||||
"""从mysql获取交易日天数差"""
|
||||
@ -494,8 +560,6 @@ class StrategyEngine(object):
|
||||
|
||||
if counts >= 3600:
|
||||
|
||||
|
||||
|
||||
self.__executeMysql(sql+values)
|
||||
|
||||
print u'写入{0}条Bar记录'.format(counts)
|
||||
|
397
vn.strategy/strategydemo/stratetyProduceBar.py
Normal file
397
vn.strategy/strategydemo/stratetyProduceBar.py
Normal file
@ -0,0 +1,397 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
# 首先写系统内置模块
|
||||
import sys
|
||||
print u'demoStrategy.py import sys success'
|
||||
from datetime import datetime, timedelta, time, date
|
||||
print u'demoStrategy.py import datetime.datetime/timedelta/time success'
|
||||
from time import sleep
|
||||
print u'demoStrategy.py import time.sleep success'
|
||||
|
||||
|
||||
# 然后是自己编写的模块
|
||||
from demoEngine import MainEngine
|
||||
print u'demoStrategy.py import demoEngine.MainEngine success'
|
||||
from strategyEngine import *
|
||||
import vtConstant
|
||||
|
||||
import tushare as ts
|
||||
import pandas as pd
|
||||
import talib as ta
|
||||
|
||||
import MySQLdb
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cPickle
|
||||
|
||||
|
||||
########################################################################
|
||||
class StrategyProduceBar(StrategyTemplate):
|
||||
"""生成Bar线策略
|
||||
|
||||
"""
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, name, symbol, engine):
|
||||
"""Constructor"""
|
||||
|
||||
super(StrategyProduceBar, self).__init__(name, symbol, engine)
|
||||
|
||||
# 主连标签
|
||||
if len(symbol) > 4:
|
||||
self.symbolMi = symbol[:-4]
|
||||
else:
|
||||
self.symbolMi = symbol
|
||||
|
||||
|
||||
# 最新TICK数据(市场报价)
|
||||
self.currentTick = None
|
||||
|
||||
# M1 K线缓存对象
|
||||
self.barOpen = EMPTY_FLOAT
|
||||
self.barHigh = EMPTY_FLOAT
|
||||
self.barLow = EMPTY_FLOAT
|
||||
self.barClose = EMPTY_FLOAT
|
||||
self.barVolume = EMPTY_INT
|
||||
self.barTime = None
|
||||
|
||||
# M5 K线 计算数据
|
||||
|
||||
self.barM5Open = EMPTY_FLOAT
|
||||
self.barM5High = EMPTY_FLOAT
|
||||
self.barM5Low = EMPTY_FLOAT
|
||||
self.barM5Close = EMPTY_FLOAT
|
||||
self.barM5Volume = EMPTY_INT
|
||||
self.barM5Time = None
|
||||
|
||||
# 当前交易日日期
|
||||
self.curDate = None
|
||||
|
||||
# 仓位状态
|
||||
self.pos = 0 # 0表示没有仓位,1表示持有多头, -1表示持有空头
|
||||
|
||||
# 报单代码列表
|
||||
self.listOrderRef = [] # 报单号列表
|
||||
self.listStopOrder = [] # 停止单对象列表
|
||||
|
||||
# 是否完成了初始化
|
||||
self.initCompleted = False
|
||||
|
||||
# 初始化时读取的历史数据的起始日期(可以选择外部设置)
|
||||
self.startDate = None
|
||||
|
||||
|
||||
self.lineM1Bar = [] # M1 K线数据
|
||||
self.lineM5Bar = [] # M5 K线数据
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadSetting(self, setting):
|
||||
"""读取参数设定"""
|
||||
try:
|
||||
if setting['orderVolume']:
|
||||
self.orderVolume = setting['orderVolume']
|
||||
|
||||
if setting['refDays']:
|
||||
self.refDays = setting['refDays']
|
||||
|
||||
self.engine.writeLog(self.name + u'读取参数成功')
|
||||
except KeyError:
|
||||
self.engine.writeLog(self.name + u'读取参数设定出错,请检查参数字典')
|
||||
|
||||
try:
|
||||
self.initStrategy(setting['startDate'])
|
||||
except KeyError:
|
||||
self.initStrategy()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def initStrategy(self, startDate=None):
|
||||
"""初始化"""
|
||||
|
||||
# 获取 InputP个周期的5分钟Bar线数据,初始化加载入M5Bar
|
||||
|
||||
self.initCompleted = True
|
||||
|
||||
self.engine.writeLog(self.name + u'初始化完成')
|
||||
|
||||
def __initNewDate(self, symbol, endDate=datetime.today()):
|
||||
"""初始化新的一天
|
||||
1、清除多余的M1Bar& M5Bar
|
||||
2、如果隔夜持仓,需要继续清除。
|
||||
|
||||
"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTick(self, tick):
|
||||
"""行情更新
|
||||
:type tick: object
|
||||
"""
|
||||
# 保存最新的TICK
|
||||
self.currentTick = tick
|
||||
|
||||
# 首先生成datetime.time格式的时间(便于比较)
|
||||
# ticktime = self.strToTime(tick.time, tick.ms)
|
||||
ticktime = tick.time
|
||||
|
||||
tickDate = date(ticktime.year, ticktime.month, ticktime.day)
|
||||
|
||||
if tickDate != self.curDate:
|
||||
|
||||
# 更新为新的一天
|
||||
self.curDate = tickDate
|
||||
|
||||
# 初始化交易日数据
|
||||
self.__initNewDate(self.symbol, self.curDate)
|
||||
|
||||
|
||||
|
||||
# 假设是收到的第一个TICK
|
||||
if self.barOpen == 0:
|
||||
# 初始化新的K线数据
|
||||
self.barOpen = tick.lastPrice
|
||||
self.barHigh = tick.lastPrice
|
||||
self.barLow = tick.lastPrice
|
||||
self.barClose = tick.lastPrice
|
||||
self.barVolume = tick.volume
|
||||
self.barTime = ticktime
|
||||
else:
|
||||
# 如果是当前一分钟内的数据
|
||||
if ticktime.minute == self.barTime.minute and ticktime.hour == self.barTime.hour:
|
||||
# 汇总TICK生成K线
|
||||
self.barHigh = max(self.barHigh, tick.lastPrice)
|
||||
self.barLow = min(self.barLow, tick.lastPrice)
|
||||
self.barClose = tick.lastPrice
|
||||
self.barVolume = self.barVolume + tick.volume
|
||||
self.barTime = ticktime
|
||||
|
||||
# 如果是新一分钟的数据
|
||||
else:
|
||||
# 首先推送K线数据
|
||||
self.onBar(self.barOpen, self.barHigh, self.barLow, self.barClose,
|
||||
self.barVolume, self.barTime)
|
||||
|
||||
# 初始化新的K线数据
|
||||
self.barOpen = tick.lastPrice
|
||||
self.barHigh = tick.lastPrice
|
||||
self.barLow = tick.lastPrice
|
||||
self.barClose = tick.lastPrice
|
||||
self.barVolume = tick.volume
|
||||
self.barTime = ticktime
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onTrade(self, trade):
|
||||
"""交易更新"""
|
||||
|
||||
log = self.name + u'当前持仓:' + str(self.pos)
|
||||
|
||||
print log
|
||||
self.engine.writeLog(log)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onOrder(self, order):
|
||||
"""报单更新"""
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onStopOrder(self, orderRef):
|
||||
"""停止单更新"""
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onBar(self, o, h, l, c, volume, t):
|
||||
"""K线数据更新,同时进行策略的买入、卖出逻辑计算"""
|
||||
|
||||
bartime = datetime(t.year, t.month, t.day, t.hour, t.minute) # 秒可以去除
|
||||
|
||||
# 保存M1-K线数据
|
||||
bar = Bar()
|
||||
bar.symbol = self.symbol
|
||||
bar.open = o
|
||||
bar.high = h
|
||||
bar.low = l
|
||||
bar.close = c
|
||||
bar.volume = volume
|
||||
bar.date = bartime.strftime('%Y-%m-%d')
|
||||
bar.time = bartime.strftime('%H:%M:%S')
|
||||
bar.datetime = bartime
|
||||
self.lineM1Bar.append(bar)
|
||||
|
||||
# 保存M5-K线数据,更新前置条件值
|
||||
self.onM5Bar(o, h, l, c, volume, t)
|
||||
|
||||
# 交易逻辑
|
||||
if self.initCompleted: # 首先检查是否是实盘运行还是数据预处理阶段
|
||||
pass
|
||||
|
||||
# 记录日志
|
||||
log = self.name + u',K线时间:' + str(t) + '\n'
|
||||
self.engine.writeLog(log)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onM5Bar(self, o, h, l, c, volume, t):
|
||||
"""更新5分钟K线
|
||||
此方法有两个入口,一个是OnBar推送的每分钟K线,
|
||||
另一个是initStrategy推送的初始化前若干周期的M5 K线"""
|
||||
minute = t.minute - t.minute % 5
|
||||
bartime = datetime(t.year,t.month,t.day,t.hour,minute)
|
||||
|
||||
# 如果 M5为空,创建一个M5
|
||||
if len(self.lineM5Bar) == 0:
|
||||
m5bar = Bar()
|
||||
|
||||
m5bar.symbol = self.symbol
|
||||
m5bar.open = o
|
||||
m5bar.high = h
|
||||
m5bar.low = l
|
||||
m5bar.close = c
|
||||
m5bar.volume = volume
|
||||
m5bar.date = bartime.strftime('%Y-%m-%d')
|
||||
m5bar.time = bartime.strftime('%H:%M:%S')
|
||||
m5bar.datetime = bartime
|
||||
|
||||
self.lineM5Bar.append(m5bar)
|
||||
|
||||
else:
|
||||
lastM5Bar = self.lineM5Bar[-1]
|
||||
|
||||
if(t-lastM5Bar.datetime).seconds < 300:
|
||||
# 如果 新Bar数据的时间为同一个M5周期,更新M5数据
|
||||
lastM5Bar.high = max(lastM5Bar.high, h)
|
||||
lastM5Bar.low = min(lastM5Bar.low, l)
|
||||
lastM5Bar.close = c
|
||||
lastM5Bar.volume = lastM5Bar.volume + volume
|
||||
else:
|
||||
# 如果 新Bar数据的周期为下一个M5周期,重新计算PreM5系列值,创建新的M5
|
||||
m5bar = Bar()
|
||||
|
||||
m5bar.symbol = self.symbol
|
||||
m5bar.open = o
|
||||
m5bar.high = h
|
||||
m5bar.low = l
|
||||
m5bar.close = c
|
||||
m5bar.volume = volume
|
||||
m5bar.date = bartime.strftime('%Y-%m-%d')
|
||||
m5bar.time = bartime.strftime('%H:%M:%S')
|
||||
m5bar.datetime = bartime
|
||||
|
||||
self.lineM5Bar.append(m5bar)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def strToTime(self, t, ms):
|
||||
"""从字符串时间转化为time格式的时间"""
|
||||
hh, mm, ss = t.split(':')
|
||||
tt = time(int(hh), int(mm), int(ss), microsecond=ms)
|
||||
return tt
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def saveData(self, id):
|
||||
"""保存过程数据"""
|
||||
# 保存K线
|
||||
print u'{0}保存K线'.format(self.name)
|
||||
self.__saveBarToMysql('M1', self.lineM1Bar)
|
||||
self.__saveBarToMysql('M5', self.lineM5Bar)
|
||||
|
||||
|
||||
def __saveBarToMysql(self,barType, barList):
|
||||
"""
|
||||
保存K线数据到数据库
|
||||
id, 回测ID
|
||||
barList, 对象为Bar的列表
|
||||
"""
|
||||
|
||||
# 保存本地pickle文件
|
||||
resultPath=os.getcwd()+'/result'
|
||||
|
||||
if not os.path.isdir(resultPath):
|
||||
os.mkdir(resultPath)
|
||||
|
||||
resultFile = u'{0}/{1}_{2}Bar.pickle'.format(resultPath,self.symbolMi, barType)
|
||||
|
||||
cache= open(resultFile, mode='w')
|
||||
|
||||
cPickle.dump(barList,cache)
|
||||
|
||||
cache.close()
|
||||
|
||||
# 保存数据库
|
||||
|
||||
self.__connectMysql()
|
||||
|
||||
if self.__mysqlConnected:
|
||||
sql = 'insert ignore into stockcn.TB_{0}{1} ' \
|
||||
'( open,high, low,close,date,time,datetime,volume) ' \
|
||||
'values '.format(self.symbolMi,barType)
|
||||
|
||||
values = ''
|
||||
|
||||
print u'共{0}条{1}Bar记录.'.format(len(barList),barType)
|
||||
|
||||
if len(barList) == 0:
|
||||
return
|
||||
|
||||
counts = 0
|
||||
|
||||
for bar in barList:
|
||||
|
||||
if len(values) > 0:
|
||||
values = values + ','
|
||||
|
||||
values = values + '({0},{1},{2},{3},\'{4}\',\'{5}\',\'{6}\',{7})'.format(
|
||||
bar.open,
|
||||
bar.high,
|
||||
bar.low,
|
||||
bar.close,
|
||||
bar.date,
|
||||
bar.time,
|
||||
bar.datetime.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
bar.volume
|
||||
)
|
||||
|
||||
if counts >= 3600:
|
||||
|
||||
self.__executeMysql(sql+values)
|
||||
|
||||
print u'写入{0}条{1}Bar记录'.format(counts,barType)
|
||||
|
||||
counts = 0
|
||||
values = ''
|
||||
|
||||
else:
|
||||
counts = counts + 1
|
||||
|
||||
if counts > 0:
|
||||
|
||||
self.__executeMysql(sql+values)
|
||||
print u'写入{0}条{1}Bar记录'.format(counts,barType)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __connectMysql(self):
|
||||
"""连接MysqlDB"""
|
||||
try:
|
||||
self.__mysqlConnection = MySQLdb.connect(host='vnpy.cloudapp.net', user='stockcn', passwd='7uhb*IJN', db='stockcn', port=3306)
|
||||
self.__mysqlConnected = True
|
||||
print u'策略连接MysqlDB成功'
|
||||
except ConnectionFailure:
|
||||
print u'策略连接MysqlDB失败'
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __executeMysql(self, sql):
|
||||
"""执行mysql语句"""
|
||||
if not self.__mysqlConnected:
|
||||
self.__connectMysql()
|
||||
|
||||
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
|
||||
|
||||
try:
|
||||
cur.execute(sql)
|
||||
self.__mysqlConnection.commit()
|
||||
|
||||
except Exception, e:
|
||||
print e
|
||||
print sql
|
||||
|
||||
self.__connectMysql()
|
||||
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
|
||||
cur.execute(sql)
|
||||
self.__mysqlConnection.commit()
|
22
vn.training/TushareGetAllDay.py
Normal file
22
vn.training/TushareGetAllDay.py
Normal file
@ -0,0 +1,22 @@
|
||||
#coding=utf8
|
||||
|
||||
import tushare as ts
|
||||
import pymongo
|
||||
import json
|
||||
#import csv
|
||||
import time
|
||||
|
||||
#conn = pymongo.Connection('222.73.15.21', port=7017)
|
||||
|
||||
i=0
|
||||
|
||||
while i < 10000:
|
||||
print time.localtime
|
||||
i=i+1
|
||||
code = "%06d" % i
|
||||
print code
|
||||
df = ts.get_hist_data(code)
|
||||
|
||||
#conn.db.DayAll.insert(json.loads(df.to_json(orient='index')))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user