From 1324e5ab2d9be0b2ee9358b79bba482f3c8eccb8 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Sat, 24 Oct 2015 01:12:08 +0800 Subject: [PATCH] local cache --- vn.strategy/strategydemo/backtestingEngine.py | 104 ++++++++++++++++-- vn.strategy/strategydemo/demoBacktesting.py | 6 +- vn.strategy/strategydemo/strategyEngine.py | 41 +++++++ vn.training/day_test.py | 17 +++ vn.training/draw.py | 2 +- 5 files changed, 159 insertions(+), 11 deletions(-) create mode 100644 vn.training/day_test.py diff --git a/vn.strategy/strategydemo/backtestingEngine.py b/vn.strategy/strategydemo/backtestingEngine.py index 0a164fbc..a30f51f2 100644 --- a/vn.strategy/strategydemo/backtestingEngine.py +++ b/vn.strategy/strategydemo/backtestingEngine.py @@ -9,6 +9,9 @@ from pymongo.errors import * from datetime import datetime, timedelta, time from strategyEngine import * +import sys +import os +import cPickle @@ -117,6 +120,71 @@ class BacktestingEngine(object): self.writeLog(u'回测引擎连接MysqlDB成功') except ConnectionFailure: self.writeLog(u'回测引擎连接MysqlDB失败') + + #---------------------------------------------------------------------- + def loadDataHistory(self, symbol, startDate, endDate): + """载入历史TICK数据,""" + if not endDate: + endDate = datetime.today() + + # 看本地缓存是否存在 + if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate): + self.writeLog(u'历史TICK数据从Cache载入') + return + + intervalDays = 10 + + for i in range (0,(endDate - startDate).days +1, intervalDays): + d1 = startDate + timedelta(days = i ) + + if (endDate - d1).days > 10: + d2 = startDate + timedelta(days = i + intervalDays -1 ) + else: + d2 = endDate + + self.loadMysqlDataHistory(symbol, d1, d2) + + self.writeLog(u'历史TICK数据共载入{0}条'.format(len(self.listDataHistory))) + self.__saveDataHistoryToLocalCache(symbol, startDate, endDate) + + + def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate): + """看本地缓存是否存在""" + + cacheFolder = os.getcwd()+'\\cache' + + + cacheFile = u'{0}\\{1}_{2}_{3}.pickle'.format(cacheFolder,symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d')) + + if not os.path.isfile(cacheFile): + return False + + else: + cache = open(cacheFile,mode='r') + self.listDataHistory = cPickle.load(cache) + return True + + def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate): + """保存本地缓存""" + cacheFolder = os.getcwd()+'\\cache' + + if not os.path.isdir(cacheFolder): + os.mkdir(cacheFolder) + + cacheFile = u'{0}\\{1}_{2}_{3}.pickle'.format(cacheFolder,symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d')) + + if os.path.isfile(cacheFile): + return False + + else: + cache= open(cacheFile, mode='w') + + cPickle.dump(self.listDataHistory,cache) + + cache.close() + + return True + #---------------------------------------------------------------------- def loadMysqlDataHistory(self, symbol, startDate, endDate): """从Mysql载入历史TICK数据,""" @@ -125,6 +193,7 @@ class BacktestingEngine(object): if self.__mysqlConnected: + #获取指针 cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor) @@ -154,7 +223,7 @@ class BacktestingEngine(object): self.writeLog(sqlstring) count = cur.execute(sqlstring) - self.writeLog(u'历史TICK数据共{0}条'.format(count)) + #self.writeLog(u'历史TICK数据共{0}条'.format(count)) # 将TICK数据读入内存 #self.listDataHistory = cur.fetchall() @@ -168,7 +237,7 @@ class BacktestingEngine(object): if not results: break - fetch_counts = fetch_counts+fetch_size + fetch_counts = fetch_counts + len(results) if not self.listDataHistory: @@ -177,9 +246,8 @@ class BacktestingEngine(object): else: self.listDataHistory = self.listDataHistory + results - self.writeLog(u'历史TICK数据载入{0}条'.format(fetch_counts)) + self.writeLog(u'{1}~{2}历史TICK数据载入共{0}条'.format(fetch_counts,startDate,endDate)) - self.writeLog(u'历史TICK数据载入完成,{1}~{2},共{0}条'.format(count,startDate,endDate)) else: self.writeLog(u'MysqlDB未连接,请检查') @@ -311,10 +379,14 @@ class BacktestingEngine(object): event.dict_['data'] = data self.strategyEngine.updateMarketData(event) + # 保存交易到本地结果 + self.saveTradeData() + # 保存交易到数据库中 self.saveTradeDataToMysql() + t2 = datetime.now() self.writeLog(u'回测结束,{0},耗时:{1}秒'.format(str(t2),(t2-t1).seconds)) @@ -370,15 +442,33 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def saveTradeData(self): """保存交易记录""" - f = shelve.open('result.vn') - f['listTrade'] = self.listTrade - f.close() + #f = shelve.open('result.vn') + #f['listTrade'] = self.listTrade + #f.close() + + # 保存本地pickle文件 + resultPath=os.getcwd()+'\\result' + + if not os.path.isdir(resultPath): + os.mkdir(resultPath) + + resultFile = u'{0}\\{1}_Trade.pickle'.format(resultPath, self.Id) + + cache= open(resultFile, mode='w') + + cPickle.dump(self.listTrade,cache) + + cache.close() + """仿真订阅合约""" pass #---------------------------------------------------------------------- def saveTradeDataToMysql(self): """保存交易记录到mysql,added by Incense Lee""" + + self.connectMysql() + if self.__mysqlConnected: sql='insert into BackTest.TB_Trade (Id,symbol,orderRef,tradeID,direction,offset,price,volume,tradeTime,amount) values ' values = '' diff --git a/vn.strategy/strategydemo/demoBacktesting.py b/vn.strategy/strategydemo/demoBacktesting.py index 07bff70b..b1636cb5 100644 --- a/vn.strategy/strategydemo/demoBacktesting.py +++ b/vn.strategy/strategydemo/demoBacktesting.py @@ -8,7 +8,7 @@ import decimal def main(): """回测程序主函数""" # symbol = 'IF1506' - symbol = 'ag' + symbol = 'a' # 创建回测引擎 be = BacktestingEngine() @@ -22,14 +22,14 @@ def main(): be.connectMysql() # be.loadMongoDataHistory(symbol, datetime(2015,5,1), datetime.today()) # be.loadMongoDataHistory(symbol, datetime(2012,1,9), datetime(2012,1,14)) - be.loadMysqlDataHistory(symbol, datetime(2012,6,9), datetime(2012,7,20)) + be.loadDataHistory(symbol, datetime(2012,1,1), datetime(2012,1,30)) # 创建策略对象 setting = {} setting['fastAlpha'] = 0.2 setting['slowAlpha'] = 0.05 # setting['startDate'] = datetime(year=2015, month=5, day=20) - setting['startDate'] = datetime(year=2012, month=6, day=9) + setting['startDate'] = datetime(year=2012, month=1, day=1) se.createStrategy(u'EMA演示策略', symbol, SimpleEmaStrategy, setting) diff --git a/vn.strategy/strategydemo/strategyEngine.py b/vn.strategy/strategydemo/strategyEngine.py index d09e861b..cb3eaa9c 100644 --- a/vn.strategy/strategydemo/strategyEngine.py +++ b/vn.strategy/strategydemo/strategyEngine.py @@ -16,6 +16,9 @@ from vtConstant import * import MySQLdb +import os +import sys +import cPickle # 常量定义 OFFSET_OPEN = '0' # 开仓 @@ -166,6 +169,8 @@ class Order(object): self.status = '' # 报单状态代码 + self.preTradeID = '' # 上一成交单编号(用于平仓) + ######################################################################## class StopOrder(object): @@ -436,6 +441,25 @@ class StrategyEngine(object): id, 回测ID barList, 对象为Bar的列表 """ + + # 保存本地pickle文件 + resultPath=os.getcwd()+'\\result' + + if not os.path.isdir(resultPath): + os.mkdir(resultPath) + + resultFile = u'{0}\\{1}_Bar.pickle'.format(resultPath,id) + + cache= open(resultFile, mode='w') + + cPickle.dump(barList,cache) + + cache.close() + + # 保存数据库 + + self.__connectMysql() + if self.__mysqlConnected: sql='insert into BackTest.TB_Bar (Id, symbol ,open ,high ,low ,close ,date ,time ,datetime, volume, openInterest) values ' values = '' @@ -491,6 +515,23 @@ class StrategyEngine(object): 保存EMA到数据库 id,回测的编号 """ + + # 保存本地pickle文件 + resultPath=os.getcwd()+'\\result' + + if not os.path.isdir(resultPath): + os.mkdir(resultPath) + + resultFile = u'{0}\\{1}_Ema.pickle'.format(resultPath, id) + + cache= open(resultFile, mode='w') + + cPickle.dump(emaList,cache) + + cache.close() + + self.__connectMysql() + if self.__mysqlConnected: sql='insert into BackTest.TB_Ema (Id, symbol ,fastEMA,slowEMA ,date ,time ,datetime) values ' values = '' diff --git a/vn.training/day_test.py b/vn.training/day_test.py new file mode 100644 index 00000000..0ecab66b --- /dev/null +++ b/vn.training/day_test.py @@ -0,0 +1,17 @@ +# encoding: UTF-8 + +import datetime + +startDate = datetime.date(2012,6,1) +endDate = datetime.date(2012,12,31) +for i in range (0,(endDate - startDate).days +1, 10): + d1 = startDate + datetime.timedelta(days= i ) + + if (endDate - d1).days > 10: + d2 = startDate + datetime.timedelta(days= i+9 ) + else: + d2 = endDate + + i = i + 10 + + print str(d1),str(d2) \ No newline at end of file diff --git a/vn.training/draw.py b/vn.training/draw.py index adbb0d8d..d67e9936 100644 --- a/vn.training/draw.py +++ b/vn.training/draw.py @@ -951,7 +951,7 @@ if __name__ == '__main__': 程序目录= os.getcwd() - pfile= open(file=os.path.join(程序目录, '绘图数据.pickle'), mode='rb') + pfile= open(file=os.path.join(程序目录, 'draw_data.pickle'), mode='rb') 绘图数据= pickle.load(pfile) pfile.close()