local cache
This commit is contained in:
parent
e48ba99820
commit
1324e5ab2d
@ -9,6 +9,9 @@ from pymongo.errors import *
|
||||
from datetime import datetime, timedelta, time
|
||||
|
||||
from strategyEngine import *
|
||||
import sys
|
||||
import os
|
||||
import cPickle
|
||||
|
||||
|
||||
|
||||
@ -117,6 +120,71 @@ class BacktestingEngine(object):
|
||||
self.writeLog(u'回测引擎连接MysqlDB成功')
|
||||
except ConnectionFailure:
|
||||
self.writeLog(u'回测引擎连接MysqlDB失败')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadDataHistory(self, symbol, startDate, endDate):
|
||||
"""载入历史TICK数据,"""
|
||||
if not endDate:
|
||||
endDate = datetime.today()
|
||||
|
||||
# 看本地缓存是否存在
|
||||
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
|
||||
self.writeLog(u'历史TICK数据从Cache载入')
|
||||
return
|
||||
|
||||
intervalDays = 10
|
||||
|
||||
for i in range (0,(endDate - startDate).days +1, intervalDays):
|
||||
d1 = startDate + timedelta(days = i )
|
||||
|
||||
if (endDate - d1).days > 10:
|
||||
d2 = startDate + timedelta(days = i + intervalDays -1 )
|
||||
else:
|
||||
d2 = endDate
|
||||
|
||||
self.loadMysqlDataHistory(symbol, d1, d2)
|
||||
|
||||
self.writeLog(u'历史TICK数据共载入{0}条'.format(len(self.listDataHistory)))
|
||||
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
|
||||
|
||||
|
||||
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
|
||||
"""看本地缓存是否存在"""
|
||||
|
||||
cacheFolder = os.getcwd()+'\\cache'
|
||||
|
||||
|
||||
cacheFile = u'{0}\\{1}_{2}_{3}.pickle'.format(cacheFolder,symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
|
||||
|
||||
if not os.path.isfile(cacheFile):
|
||||
return False
|
||||
|
||||
else:
|
||||
cache = open(cacheFile,mode='r')
|
||||
self.listDataHistory = cPickle.load(cache)
|
||||
return True
|
||||
|
||||
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
|
||||
"""保存本地缓存"""
|
||||
cacheFolder = os.getcwd()+'\\cache'
|
||||
|
||||
if not os.path.isdir(cacheFolder):
|
||||
os.mkdir(cacheFolder)
|
||||
|
||||
cacheFile = u'{0}\\{1}_{2}_{3}.pickle'.format(cacheFolder,symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
|
||||
|
||||
if os.path.isfile(cacheFile):
|
||||
return False
|
||||
|
||||
else:
|
||||
cache= open(cacheFile, mode='w')
|
||||
|
||||
cPickle.dump(self.listDataHistory,cache)
|
||||
|
||||
cache.close()
|
||||
|
||||
return True
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadMysqlDataHistory(self, symbol, startDate, endDate):
|
||||
"""从Mysql载入历史TICK数据,"""
|
||||
@ -125,6 +193,7 @@ class BacktestingEngine(object):
|
||||
|
||||
if self.__mysqlConnected:
|
||||
|
||||
|
||||
#获取指针
|
||||
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
|
||||
|
||||
@ -154,7 +223,7 @@ class BacktestingEngine(object):
|
||||
self.writeLog(sqlstring)
|
||||
|
||||
count = cur.execute(sqlstring)
|
||||
self.writeLog(u'历史TICK数据共{0}条'.format(count))
|
||||
#self.writeLog(u'历史TICK数据共{0}条'.format(count))
|
||||
|
||||
# 将TICK数据读入内存
|
||||
#self.listDataHistory = cur.fetchall()
|
||||
@ -168,7 +237,7 @@ class BacktestingEngine(object):
|
||||
if not results:
|
||||
break
|
||||
|
||||
fetch_counts = fetch_counts+fetch_size
|
||||
fetch_counts = fetch_counts + len(results)
|
||||
|
||||
if not self.listDataHistory:
|
||||
|
||||
@ -177,9 +246,8 @@ class BacktestingEngine(object):
|
||||
else:
|
||||
self.listDataHistory = self.listDataHistory + results
|
||||
|
||||
self.writeLog(u'历史TICK数据载入{0}条'.format(fetch_counts))
|
||||
self.writeLog(u'{1}~{2}历史TICK数据载入共{0}条'.format(fetch_counts,startDate,endDate))
|
||||
|
||||
self.writeLog(u'历史TICK数据载入完成,{1}~{2},共{0}条'.format(count,startDate,endDate))
|
||||
|
||||
else:
|
||||
self.writeLog(u'MysqlDB未连接,请检查')
|
||||
@ -311,10 +379,14 @@ class BacktestingEngine(object):
|
||||
event.dict_['data'] = data
|
||||
self.strategyEngine.updateMarketData(event)
|
||||
|
||||
# 保存交易到本地结果
|
||||
self.saveTradeData()
|
||||
|
||||
# 保存交易到数据库中
|
||||
self.saveTradeDataToMysql()
|
||||
|
||||
|
||||
|
||||
t2 = datetime.now()
|
||||
|
||||
self.writeLog(u'回测结束,{0},耗时:{1}秒'.format(str(t2),(t2-t1).seconds))
|
||||
@ -370,15 +442,33 @@ class BacktestingEngine(object):
|
||||
#----------------------------------------------------------------------
|
||||
def saveTradeData(self):
|
||||
"""保存交易记录"""
|
||||
f = shelve.open('result.vn')
|
||||
f['listTrade'] = self.listTrade
|
||||
f.close()
|
||||
#f = shelve.open('result.vn')
|
||||
#f['listTrade'] = self.listTrade
|
||||
#f.close()
|
||||
|
||||
# 保存本地pickle文件
|
||||
resultPath=os.getcwd()+'\\result'
|
||||
|
||||
if not os.path.isdir(resultPath):
|
||||
os.mkdir(resultPath)
|
||||
|
||||
resultFile = u'{0}\\{1}_Trade.pickle'.format(resultPath, self.Id)
|
||||
|
||||
cache= open(resultFile, mode='w')
|
||||
|
||||
cPickle.dump(self.listTrade,cache)
|
||||
|
||||
cache.close()
|
||||
|
||||
"""仿真订阅合约"""
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def saveTradeDataToMysql(self):
|
||||
"""保存交易记录到mysql,added by Incense Lee"""
|
||||
|
||||
self.connectMysql()
|
||||
|
||||
if self.__mysqlConnected:
|
||||
sql='insert into BackTest.TB_Trade (Id,symbol,orderRef,tradeID,direction,offset,price,volume,tradeTime,amount) values '
|
||||
values = ''
|
||||
|
@ -8,7 +8,7 @@ import decimal
|
||||
def main():
|
||||
"""回测程序主函数"""
|
||||
# symbol = 'IF1506'
|
||||
symbol = 'ag'
|
||||
symbol = 'a'
|
||||
|
||||
# 创建回测引擎
|
||||
be = BacktestingEngine()
|
||||
@ -22,14 +22,14 @@ def main():
|
||||
be.connectMysql()
|
||||
# be.loadMongoDataHistory(symbol, datetime(2015,5,1), datetime.today())
|
||||
# be.loadMongoDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,14))
|
||||
be.loadMysqlDataHistory(symbol, datetime(2012,6,9), datetime(2012,7,20))
|
||||
be.loadDataHistory(symbol, datetime(2012,1,1), datetime(2012,1,30))
|
||||
|
||||
# 创建策略对象
|
||||
setting = {}
|
||||
setting['fastAlpha'] = 0.2
|
||||
setting['slowAlpha'] = 0.05
|
||||
# setting['startDate'] = datetime(year=2015, month=5, day=20)
|
||||
setting['startDate'] = datetime(year=2012, month=6, day=9)
|
||||
setting['startDate'] = datetime(year=2012, month=1, day=1)
|
||||
|
||||
se.createStrategy(u'EMA演示策略', symbol, SimpleEmaStrategy, setting)
|
||||
|
||||
|
@ -16,6 +16,9 @@ from vtConstant import *
|
||||
|
||||
import MySQLdb
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cPickle
|
||||
|
||||
# 常量定义
|
||||
OFFSET_OPEN = '0' # 开仓
|
||||
@ -166,6 +169,8 @@ class Order(object):
|
||||
|
||||
self.status = '' # 报单状态代码
|
||||
|
||||
self.preTradeID = '' # 上一成交单编号(用于平仓)
|
||||
|
||||
|
||||
########################################################################
|
||||
class StopOrder(object):
|
||||
@ -436,6 +441,25 @@ class StrategyEngine(object):
|
||||
id, 回测ID
|
||||
barList, 对象为Bar的列表
|
||||
"""
|
||||
|
||||
# 保存本地pickle文件
|
||||
resultPath=os.getcwd()+'\\result'
|
||||
|
||||
if not os.path.isdir(resultPath):
|
||||
os.mkdir(resultPath)
|
||||
|
||||
resultFile = u'{0}\\{1}_Bar.pickle'.format(resultPath,id)
|
||||
|
||||
cache= open(resultFile, mode='w')
|
||||
|
||||
cPickle.dump(barList,cache)
|
||||
|
||||
cache.close()
|
||||
|
||||
# 保存数据库
|
||||
|
||||
self.__connectMysql()
|
||||
|
||||
if self.__mysqlConnected:
|
||||
sql='insert into BackTest.TB_Bar (Id, symbol ,open ,high ,low ,close ,date ,time ,datetime, volume, openInterest) values '
|
||||
values = ''
|
||||
@ -491,6 +515,23 @@ class StrategyEngine(object):
|
||||
保存EMA到数据库
|
||||
id,回测的编号
|
||||
"""
|
||||
|
||||
# 保存本地pickle文件
|
||||
resultPath=os.getcwd()+'\\result'
|
||||
|
||||
if not os.path.isdir(resultPath):
|
||||
os.mkdir(resultPath)
|
||||
|
||||
resultFile = u'{0}\\{1}_Ema.pickle'.format(resultPath, id)
|
||||
|
||||
cache= open(resultFile, mode='w')
|
||||
|
||||
cPickle.dump(emaList,cache)
|
||||
|
||||
cache.close()
|
||||
|
||||
self.__connectMysql()
|
||||
|
||||
if self.__mysqlConnected:
|
||||
sql='insert into BackTest.TB_Ema (Id, symbol ,fastEMA,slowEMA ,date ,time ,datetime) values '
|
||||
values = ''
|
||||
|
17
vn.training/day_test.py
Normal file
17
vn.training/day_test.py
Normal file
@ -0,0 +1,17 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
import datetime
|
||||
|
||||
startDate = datetime.date(2012,6,1)
|
||||
endDate = datetime.date(2012,12,31)
|
||||
for i in range (0,(endDate - startDate).days +1, 10):
|
||||
d1 = startDate + datetime.timedelta(days= i )
|
||||
|
||||
if (endDate - d1).days > 10:
|
||||
d2 = startDate + datetime.timedelta(days= i+9 )
|
||||
else:
|
||||
d2 = endDate
|
||||
|
||||
i = i + 10
|
||||
|
||||
print str(d1),str(d2)
|
@ -951,7 +951,7 @@ if __name__ == '__main__':
|
||||
|
||||
程序目录= os.getcwd()
|
||||
|
||||
pfile= open(file=os.path.join(程序目录, '绘图数据.pickle'), mode='rb')
|
||||
pfile= open(file=os.path.join(程序目录, 'draw_data.pickle'), mode='rb')
|
||||
绘图数据= pickle.load(pfile)
|
||||
pfile.close()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user