This commit is contained in:
msincenselee 2015-11-05 00:28:41 +08:00
parent 43f96b1e05
commit f28351a242
6 changed files with 541 additions and 8 deletions

1
vn.strategy/__init__.py Normal file
View File

@ -0,0 +1 @@
__author__ = 'Incense'

View File

@ -0,0 +1 @@
__author__ = 'Incense'

View File

@ -0,0 +1,48 @@
# encoding: UTF-8
from strategyEngine import *
from backtestingEngine import *
from stratetyProduceBar import StrategyProduceBar
import decimal
def main():
"""回测程序主函数"""
# symbol = 'IF1506'
symbol = 'a'
# 创建回测引擎
be = BacktestingEngine()
# 创建策略引擎对象
se = StrategyEngine(be.eventEngine, be, backtesting=True)
be.setStrategyEngine(se)
# 初始化回测引擎
# be.connectMongo()
be.connectMysql()
# be.loadMongoDataHistory(symbol, datetime(2015,5,1), datetime.today())
# be.loadMongoDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,14))
be.setDataHistory(symbol, datetime(2012,1,1), datetime(2012,1,31))
# 创建策略对象
setting = {}
#setting['fastAlpha'] = 0.2
#setting['slowAlpha'] = 0.05
# setting['startDate'] = datetime(year=2015, month=5, day=20)
setting['startDate'] = datetime(year=2012, month=1, day=1)
se.createStrategy(u'生成M1M5数据策略', symbol, StrategyProduceBar, setting)
# 启动所有策略
se.startAll()
# 开始回测
be.startBacktesting()
# 回测脚本
if __name__ == '__main__':
main()

View File

@ -24,16 +24,16 @@ import cPickle
OFFSET_OPEN = '0' # 开仓
OFFSET_CLOSE = '1' # 平仓
DIRECTION_BUY = '0' # 买入
DIRECTION_SELL = '1' # 卖出
DIRECTION_BUY = '0' # 买入
DIRECTION_SELL = '1' # 卖出
PRICETYPE_LIMIT = '2' # 限价
# buy 买入开仓 : DIRECTION_BUY = '0' OFFSET_OPEN = '0'
# sell 卖出平仓 : DIRECTION_SELL = '1' OFFSET_CLOSE = '1'
# buy 买入开仓 开多 : DIRECTION_BUY = '0' OFFSET_OPEN = '0'
# sell 卖出平仓 平多 : DIRECTION_SELL = '1' OFFSET_CLOSE = '1'
# short 卖出开仓 : DIRECTION_SELL = '1' OFFSET_OPEN = '0'
# cover 买入平仓 : DIRECTION_BUY = '0' OFFSET_CLOSE = '1'
# short 卖出开仓 开空 : DIRECTION_SELL = '1' OFFSET_OPEN = '0'
# cover 买入平仓 平空 : DIRECTION_BUY = '0' OFFSET_CLOSE = '1'
########################################################################
class Tick:
@ -400,6 +400,72 @@ class StrategyEngine(object):
else:
return None
#----------------------------------------------------------------------
def loadBarFromMysql(self, symbol, startDate, endDate, barType='M5'):
"""从MysqlDB中读取Bar数据
startDate,包含开始日期
endDate 包含结束日期
barType ='M1' K线类型,1分钟线
barType ='M5' K线类型5分钟线
barType = 'D1' K线类型日线
"""
if self.__mysqlConnected:
#获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
# 指定开始与结束日期
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
'where date between cast(\'{2}\' as date) and cast(\'{3}\' as date) ' \
'order by datetime'.format(symbol, barType, startDate, endDate)
elif startDate:
# 指定开始日期
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
'where date > cast(\'{2}\' as date) order by datetime'.format(symbol,
barType, startDate)
else:
# 没有指定,所有日期数据
sqlstring = 'select open,high,low,close,volume,date,time,datetime from TB_{0}{1} ' \
' order by datetime'.format(symbol, barType)
print sqlstring
count = cur.execute(sqlstring)
# cx = cur.fetchall()
fetch_counts = 0
cx = None
fetch_size = 1000
while True:
results = cur.fetchmany(fetch_size)
if not results:
break
if fetch_counts == 0:
cx = results
else:
cx = cx + results
fetch_counts = fetch_counts+fetch_size
print u'历史{0}Bar数据载入{1}'.format(barType, fetch_counts)
self.writeLog(u'历史{0}Bar数据载入完成{1}~{2},共{3}'.format(barType,startDate, endDate, count))
print u'策略引擎:历史{0}Bar数据载入完成{1}~{2},共{3}'.format(barType,startDate, endDate, count)
return cx
else:
return None
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
"""从mysql获取交易日天数差"""
@ -494,8 +560,6 @@ class StrategyEngine(object):
if counts >= 3600:
self.__executeMysql(sql+values)
print u'写入{0}条Bar记录'.format(counts)

View File

@ -0,0 +1,397 @@
# encoding: UTF-8
# 首先写系统内置模块
import sys
print u'demoStrategy.py import sys success'
from datetime import datetime, timedelta, time, date
print u'demoStrategy.py import datetime.datetime/timedelta/time success'
from time import sleep
print u'demoStrategy.py import time.sleep success'
# 然后是自己编写的模块
from demoEngine import MainEngine
print u'demoStrategy.py import demoEngine.MainEngine success'
from strategyEngine import *
import vtConstant
import tushare as ts
import pandas as pd
import talib as ta
import MySQLdb
import os
import sys
import cPickle
########################################################################
class StrategyProduceBar(StrategyTemplate):
"""生成Bar线策略
"""
#----------------------------------------------------------------------
def __init__(self, name, symbol, engine):
"""Constructor"""
super(StrategyProduceBar, self).__init__(name, symbol, engine)
# 主连标签
if len(symbol) > 4:
self.symbolMi = symbol[:-4]
else:
self.symbolMi = symbol
# 最新TICK数据市场报价
self.currentTick = None
# M1 K线缓存对象
self.barOpen = EMPTY_FLOAT
self.barHigh = EMPTY_FLOAT
self.barLow = EMPTY_FLOAT
self.barClose = EMPTY_FLOAT
self.barVolume = EMPTY_INT
self.barTime = None
# M5 K线 计算数据
self.barM5Open = EMPTY_FLOAT
self.barM5High = EMPTY_FLOAT
self.barM5Low = EMPTY_FLOAT
self.barM5Close = EMPTY_FLOAT
self.barM5Volume = EMPTY_INT
self.barM5Time = None
# 当前交易日日期
self.curDate = None
# 仓位状态
self.pos = 0 # 0表示没有仓位1表示持有多头 -1表示持有空头
# 报单代码列表
self.listOrderRef = [] # 报单号列表
self.listStopOrder = [] # 停止单对象列表
# 是否完成了初始化
self.initCompleted = False
# 初始化时读取的历史数据的起始日期(可以选择外部设置)
self.startDate = None
self.lineM1Bar = [] # M1 K线数据
self.lineM5Bar = [] # M5 K线数据
#----------------------------------------------------------------------
def loadSetting(self, setting):
"""读取参数设定"""
try:
if setting['orderVolume']:
self.orderVolume = setting['orderVolume']
if setting['refDays']:
self.refDays = setting['refDays']
self.engine.writeLog(self.name + u'读取参数成功')
except KeyError:
self.engine.writeLog(self.name + u'读取参数设定出错,请检查参数字典')
try:
self.initStrategy(setting['startDate'])
except KeyError:
self.initStrategy()
#----------------------------------------------------------------------
def initStrategy(self, startDate=None):
"""初始化"""
# 获取 InputP个周期的5分钟Bar线数据初始化加载入M5Bar
self.initCompleted = True
self.engine.writeLog(self.name + u'初始化完成')
def __initNewDate(self, symbol, endDate=datetime.today()):
"""初始化新的一天
1清除多余的M1Bar M5Bar
2如果隔夜持仓需要继续清除
"""
#----------------------------------------------------------------------
def onTick(self, tick):
"""行情更新
:type tick: object
"""
# 保存最新的TICK
self.currentTick = tick
# 首先生成datetime.time格式的时间便于比较
# ticktime = self.strToTime(tick.time, tick.ms)
ticktime = tick.time
tickDate = date(ticktime.year, ticktime.month, ticktime.day)
if tickDate != self.curDate:
# 更新为新的一天
self.curDate = tickDate
# 初始化交易日数据
self.__initNewDate(self.symbol, self.curDate)
# 假设是收到的第一个TICK
if self.barOpen == 0:
# 初始化新的K线数据
self.barOpen = tick.lastPrice
self.barHigh = tick.lastPrice
self.barLow = tick.lastPrice
self.barClose = tick.lastPrice
self.barVolume = tick.volume
self.barTime = ticktime
else:
# 如果是当前一分钟内的数据
if ticktime.minute == self.barTime.minute and ticktime.hour == self.barTime.hour:
# 汇总TICK生成K线
self.barHigh = max(self.barHigh, tick.lastPrice)
self.barLow = min(self.barLow, tick.lastPrice)
self.barClose = tick.lastPrice
self.barVolume = self.barVolume + tick.volume
self.barTime = ticktime
# 如果是新一分钟的数据
else:
# 首先推送K线数据
self.onBar(self.barOpen, self.barHigh, self.barLow, self.barClose,
self.barVolume, self.barTime)
# 初始化新的K线数据
self.barOpen = tick.lastPrice
self.barHigh = tick.lastPrice
self.barLow = tick.lastPrice
self.barClose = tick.lastPrice
self.barVolume = tick.volume
self.barTime = ticktime
#----------------------------------------------------------------------
def onTrade(self, trade):
"""交易更新"""
log = self.name + u'当前持仓:' + str(self.pos)
print log
self.engine.writeLog(log)
#----------------------------------------------------------------------
def onOrder(self, order):
"""报单更新"""
pass
#----------------------------------------------------------------------
def onStopOrder(self, orderRef):
"""停止单更新"""
pass
#----------------------------------------------------------------------
def onBar(self, o, h, l, c, volume, t):
"""K线数据更新同时进行策略的买入、卖出逻辑计算"""
bartime = datetime(t.year, t.month, t.day, t.hour, t.minute) # 秒可以去除
# 保存M1-K线数据
bar = Bar()
bar.symbol = self.symbol
bar.open = o
bar.high = h
bar.low = l
bar.close = c
bar.volume = volume
bar.date = bartime.strftime('%Y-%m-%d')
bar.time = bartime.strftime('%H:%M:%S')
bar.datetime = bartime
self.lineM1Bar.append(bar)
# 保存M5-K线数据更新前置条件值
self.onM5Bar(o, h, l, c, volume, t)
# 交易逻辑
if self.initCompleted: # 首先检查是否是实盘运行还是数据预处理阶段
pass
# 记录日志
log = self.name + u'K线时间' + str(t) + '\n'
self.engine.writeLog(log)
#----------------------------------------------------------------------
def onM5Bar(self, o, h, l, c, volume, t):
"""更新5分钟K线
此方法有两个入口一个是OnBar推送的每分钟K线
另一个是initStrategy推送的初始化前若干周期的M5 K线"""
minute = t.minute - t.minute % 5
bartime = datetime(t.year,t.month,t.day,t.hour,minute)
# 如果 M5为空创建一个M5
if len(self.lineM5Bar) == 0:
m5bar = Bar()
m5bar.symbol = self.symbol
m5bar.open = o
m5bar.high = h
m5bar.low = l
m5bar.close = c
m5bar.volume = volume
m5bar.date = bartime.strftime('%Y-%m-%d')
m5bar.time = bartime.strftime('%H:%M:%S')
m5bar.datetime = bartime
self.lineM5Bar.append(m5bar)
else:
lastM5Bar = self.lineM5Bar[-1]
if(t-lastM5Bar.datetime).seconds < 300:
# 如果 新Bar数据的时间为同一个M5周期更新M5数据
lastM5Bar.high = max(lastM5Bar.high, h)
lastM5Bar.low = min(lastM5Bar.low, l)
lastM5Bar.close = c
lastM5Bar.volume = lastM5Bar.volume + volume
else:
# 如果 新Bar数据的周期为下一个M5周期重新计算PreM5系列值创建新的M5
m5bar = Bar()
m5bar.symbol = self.symbol
m5bar.open = o
m5bar.high = h
m5bar.low = l
m5bar.close = c
m5bar.volume = volume
m5bar.date = bartime.strftime('%Y-%m-%d')
m5bar.time = bartime.strftime('%H:%M:%S')
m5bar.datetime = bartime
self.lineM5Bar.append(m5bar)
#----------------------------------------------------------------------
def strToTime(self, t, ms):
"""从字符串时间转化为time格式的时间"""
hh, mm, ss = t.split(':')
tt = time(int(hh), int(mm), int(ss), microsecond=ms)
return tt
#----------------------------------------------------------------------
def saveData(self, id):
"""保存过程数据"""
# 保存K线
print u'{0}保存K线'.format(self.name)
self.__saveBarToMysql('M1', self.lineM1Bar)
self.__saveBarToMysql('M5', self.lineM5Bar)
def __saveBarToMysql(self,barType, barList):
"""
保存K线数据到数据库
id 回测ID
barList 对象为Bar的列表
"""
# 保存本地pickle文件
resultPath=os.getcwd()+'/result'
if not os.path.isdir(resultPath):
os.mkdir(resultPath)
resultFile = u'{0}/{1}_{2}Bar.pickle'.format(resultPath,self.symbolMi, barType)
cache= open(resultFile, mode='w')
cPickle.dump(barList,cache)
cache.close()
# 保存数据库
self.__connectMysql()
if self.__mysqlConnected:
sql = 'insert ignore into stockcn.TB_{0}{1} ' \
'( open,high, low,close,date,time,datetime,volume) ' \
'values '.format(self.symbolMi,barType)
values = ''
print u'{0}{1}Bar记录.'.format(len(barList),barType)
if len(barList) == 0:
return
counts = 0
for bar in barList:
if len(values) > 0:
values = values + ','
values = values + '({0},{1},{2},{3},\'{4}\',\'{5}\',\'{6}\',{7})'.format(
bar.open,
bar.high,
bar.low,
bar.close,
bar.date,
bar.time,
bar.datetime.strftime('%Y-%m-%d %H:%M:%S'),
bar.volume
)
if counts >= 3600:
self.__executeMysql(sql+values)
print u'写入{0}{1}Bar记录'.format(counts,barType)
counts = 0
values = ''
else:
counts = counts + 1
if counts > 0:
self.__executeMysql(sql+values)
print u'写入{0}{1}Bar记录'.format(counts,barType)
#----------------------------------------------------------------------
def __connectMysql(self):
"""连接MysqlDB"""
try:
self.__mysqlConnection = MySQLdb.connect(host='vnpy.cloudapp.net', user='stockcn', passwd='7uhb*IJN', db='stockcn', port=3306)
self.__mysqlConnected = True
print u'策略连接MysqlDB成功'
except ConnectionFailure:
print u'策略连接MysqlDB失败'
#----------------------------------------------------------------------
def __executeMysql(self, sql):
"""执行mysql语句"""
if not self.__mysqlConnected:
self.__connectMysql()
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
try:
cur.execute(sql)
self.__mysqlConnection.commit()
except Exception, e:
print e
print sql
self.__connectMysql()
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
cur.execute(sql)
self.__mysqlConnection.commit()

View File

@ -0,0 +1,22 @@
#coding=utf8
import tushare as ts
import pymongo
import json
#import csv
import time
#conn = pymongo.Connection('222.73.15.21', port=7017)
i=0
while i < 10000:
print time.localtime
i=i+1
code = "%06d" % i
print code
df = ts.get_hist_data(code)
#conn.db.DayAll.insert(json.loads(df.to_json(orient='index')))