vnpy/vn.strategy/strategydemo/backtestingEngine.py
2015-10-24 01:12:08 +08:00

515 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: UTF-8
import shelve
import MySQLdb
from eventEngine import *
from pymongo import MongoClient as Connection
from pymongo.errors import *
from datetime import datetime, timedelta, time
from strategyEngine import *
import sys
import os
import cPickle
########################################################################
class LimitOrder(object):
    """A limit order waiting to be matched by the backtesting engine."""

    #----------------------------------------------------------------------
    def __init__(self, symbol):
        """Create an empty limit order for the contract ``symbol``.

        Price and volume start at 0, direction and offset start as None,
        and the order time is stamped with the moment of construction.
        """
        # Order time, stamped at creation (field added by Incense Lee)
        self.orderTime = datetime.now()

        # Open/close flag and buy/sell side; both are filled in later
        self.offset = None
        self.direction = None

        # Number of contracts and the limit price
        self.volume = 0
        self.price = 0

        # Contract the order is placed on
        self.symbol = symbol
########################################################################
class BacktestingEngine(object):
"""
Backtesting engine. Its roles:
1. Read historical data from the database and replay it
2. Be passed in as the argument when a StrategyEngine is created
"""
#----------------------------------------------------------------------
def __init__(self):
    """Create an idle backtesting engine with no data, orders or trades."""
    # Event engine instance
    self.eventEngine = EventEngine()

    # Database connection flags. They are defined up front so that the
    # loaders can test them even when no connect method has been called
    # (or when a connection attempt failed) instead of raising
    # AttributeError.
    self.__mongoConnected = False
    self.__mysqlConnected = False

    # Strategy engine; assigned later through setStrategyEngine
    self.strategyEngine = None

    # Historical TICK data. Replay is a plain for loop, and the original
    # author notes that a list is faster than NumPy or Pandas for that.
    self.listDataHistory = []

    # Pending limit orders, keyed by order reference
    self.dictOrder = {}

    # Most recent TICK handed to the replay loop
    self.currentData = None

    # Trades produced by the backtest (a list, one dict per trade)
    self.listTrade = []

    # Running counters used to number orders and trades
    self.orderRef = 0
    self.tradeID = 0

    # Identifier of this backtest run, derived from the start time
    self.Id = datetime.now().strftime('%Y%m%d-%H%M%S')
#----------------------------------------------------------------------
def setStrategyEngine(self, engine):
    """Store ``engine`` as the strategy engine and log that it was set."""
    self.strategyEngine = engine

    doneMessage = u'策略引擎设置完成'
    self.writeLog(doneMessage)
#----------------------------------------------------------------------
def connectMongo(self):
    """Connect to MongoDB and select the ``TickDB`` database.

    The client is created with its default arguments. On success the
    connected flag is set and a success message is logged; on
    ConnectionFailure a failure message is logged and the flag stays False.
    """
    # Reset first, so a failed attempt never leaves the flag undefined or
    # still True from an earlier successful connection.
    self.__mongoConnected = False

    try:
        self.__mongoConnection = Connection()
        self.__mongoTickDB = self.__mongoConnection['TickDB']
    except ConnectionFailure:
        self.writeLog(u'回测引擎连接MongoDB失败')
    else:
        self.__mongoConnected = True
        self.writeLog(u'回测引擎连接MongoDB成功')
#----------------------------------------------------------------------
def loadMongoDataHistory(self, symbol, startDate, endDate):
    """Load historical TICK data for ``symbol`` from MongoDB into memory.

    ``startDate`` and ``endDate`` are both optional and inclusive; each one
    that is given becomes a bound on the ``date`` field. With neither, the
    whole collection is read. The result replaces ``self.listDataHistory``.
    If MongoDB is not connected, only a log message is written.
    """
    if not self.__mongoConnected:
        self.writeLog(u'MongoDB未连接请检查')
        return

    collection = self.__mongoTickDB[symbol]

    # Only include the bounds that were actually supplied. Sending
    # ``$gte: None`` when just an end date is given would not mean
    # "no lower bound" to the server.
    dateRange = {}
    if startDate:
        dateRange['$gte'] = startDate
    if endDate:
        dateRange['$lte'] = endDate

    if dateRange:
        cx = collection.find({'date': dateRange})
    else:
        cx = collection.find()

    # Read the TICK data into memory
    self.listDataHistory = list(cx)
    self.writeLog(u'历史TICK数据载入完成')
#----------------------------------------------------------------------
def connectMysql(self, host='vnpy.cloudapp.net', user='stockcn',
                 passwd='7uhb*IJN', db='stockcn', port=3306):
    """Connect to the MySQL database that holds the tick data.

    All parameters are optional; the defaults are the values that used to
    be hard-coded here, so existing callers are unaffected. On success the
    connected flag is set and a success message is logged; on a MySQLdb
    error a failure message is logged and the flag stays False.

    NOTE(review): the default credentials are committed in source code.
    They should be moved to configuration and the password rotated.
    """
    # Reset first, so a failed attempt never leaves the flag undefined or
    # still True from an earlier successful connection.
    self.__mysqlConnected = False

    try:
        self.__mysqlConnection = MySQLdb.connect(host=host, user=user,
                                                 passwd=passwd, db=db,
                                                 port=port)
    except MySQLdb.Error:
        # MySQLdb reports connection problems through its own exception
        # hierarchy; pymongo's ConnectionFailure is never raised here.
        self.writeLog(u'回测引擎连接MysqlDB失败')
    else:
        self.__mysqlConnected = True
        self.writeLog(u'回测引擎连接MysqlDB成功')
#----------------------------------------------------------------------
def loadDataHistory(self, symbol, startDate, endDate, intervalDays=10):
    """Load historical TICK data, from the local cache if possible.

    ``startDate`` is required. ``endDate`` defaults to the current
    date and time when falsy. If a local cache file exists for the
    symbol and date range it is used; otherwise the range is read from
    MySQL in consecutive batches of ``intervalDays`` days (a positive
    integer, 10 by default) and the result is written to the cache.
    """
    if not endDate:
        endDate = datetime.today()

    # Use the local cache when it exists
    if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
        self.writeLog(u'历史TICK数据从Cache载入')
        return

    totalDays = (endDate - startDate).days
    for offset in range(0, totalDays + 1, intervalDays):
        d1 = startDate + timedelta(days=offset)
        # Each batch covers exactly intervalDays days, clipped to endDate.
        # Clipping with min() keeps the batches disjoint; comparing the
        # remaining days against a fixed 10 made the final day load twice
        # when exactly 10 days remained.
        d2 = min(d1 + timedelta(days=intervalDays - 1), endDate)
        self.loadMysqlDataHistory(symbol, d1, d2)

    self.writeLog(u'历史TICK数据共载入{0}'.format(len(self.listDataHistory)))
    self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
    """Load the tick history from the local cache file, if there is one.

    Returns True when a cache file for this symbol and date range exists
    and has been unpickled into ``self.listDataHistory``; False otherwise.

    NOTE(review): unpickling executes whatever the file contains, so the
    cache folder must only ever hold files written by this engine.
    """
    cacheFile = self.__localCachePaths(symbol, startDate, endDate)[1]
    if not os.path.isfile(cacheFile):
        return False

    # Text mode is kept on purpose: the cache is written in text mode with
    # cPickle's default ASCII protocol, and read/write modes must match.
    # The with-block closes the handle even when unpickling fails.
    with open(cacheFile, mode='r') as cache:
        self.listDataHistory = cPickle.load(cache)
    return True

def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
    """Pickle ``self.listDataHistory`` into the local cache folder.

    The cache folder is created when missing. An existing cache file is
    never overwritten. Returns True when a file was written, False when
    one already existed.
    """
    cacheFolder, cacheFile = self.__localCachePaths(symbol, startDate, endDate)
    if not os.path.isdir(cacheFolder):
        os.mkdir(cacheFolder)

    if os.path.isfile(cacheFile):
        return False

    # Same text mode as the reader; see the note there.
    with open(cacheFile, mode='w') as cache:
        cPickle.dump(self.listDataHistory, cache)
    return True

def __localCachePaths(self, symbol, startDate, endDate):
    """Return ``(cacheFolder, cacheFile)`` for a symbol and date range.

    The folder is ``cache`` under the current working directory and the
    file is named ``<symbol>_<start>_<end>.pickle`` with ISO dates.
    """
    cacheFolder = os.path.join(os.getcwd(), 'cache')
    fileName = u'{0}_{1}_{2}.pickle'.format(symbol,
                                            startDate.strftime('%Y-%m-%d'),
                                            endDate.strftime('%Y-%m-%d'))
    return cacheFolder, os.path.join(cacheFolder, fileName)
#----------------------------------------------------------------------
def loadMysqlDataHistory(self, symbol, startDate, endDate):
    """Append historical TICK rows for ``symbol`` from MySQL to memory.

    Rows are read from table ``TB_<symbol>MI`` and appended to
    ``self.listDataHistory``. Both dates are optional and inclusive;
    with neither, the whole table is read. If MySQL is not connected or
    the query fails, a message is logged and nothing is raised.

    NOTE(review): ``symbol`` and the dates are formatted straight into the
    SQL text, so they must come from trusted code, never from user input.
    """
    try:
        if not self.__mysqlConnected:
            self.writeLog(u'MysqlDB未连接请检查')
            return

        # The select list is the same for every date filter. The table is
        # TB_<symbol>MI with a single underscore in all cases.
        sqlstring = (
            ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),'
            '\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,'
            'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, '
            'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI '
        ).format(symbol)

        # Both bounds are inclusive, whichever of them is supplied.
        if startDate and endDate:
            sqlstring += ('where ndate between cast(\'{0}\' as date) '
                          'and cast(\'{1}\' as date)').format(startDate, endDate)
        elif startDate:
            sqlstring += 'where ndate >= cast(\'{0}\' as date)'.format(startDate)
        elif endDate:
            sqlstring += 'where ndate <= cast(\'{0}\' as date)'.format(endDate)

        self.writeLog(sqlstring)

        # Keep the history as one list and extend it in place. This avoids
        # re-copying everything loaded so far for each fetched batch, and
        # works whether the current history is a list or a tuple.
        if not isinstance(self.listDataHistory, list):
            self.listDataHistory = list(self.listDataHistory or [])
        history = self.listDataHistory

        fetch_counts = 0
        fetch_size = 1000

        # Get a cursor that returns each row as a dict
        cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
        try:
            cur.execute(sqlstring)
            # Read the TICK data into memory in batches
            while True:
                results = cur.fetchmany(fetch_size)
                if not results:
                    break
                fetch_counts = fetch_counts + len(results)
                history.extend(results)
        finally:
            cur.close()

        self.writeLog(u'{1}~{2}历史TICK数据载入共{0}'.format(fetch_counts, startDate, endDate))

    except MySQLdb.Error as e:
        self.writeLog(u'MysqlDB载入数据失败请检查.Error {0}'.format(e))
#----------------------------------------------------------------------
def getMysqlDeltaDate(self, symbol, startDate, decreaseDays):
    """Return the trading date ``decreaseDays`` trading days before ``startDate``.

    The distinct ``ndate`` values of table ``TB_<symbol>MI`` that are
    earlier than ``startDate`` are ordered newest first, and the
    ``decreaseDays``-th one is returned. When MySQL is not connected, the
    query finds nothing, or the query fails, a message is logged and
    ``startDate`` minus three calendar days is returned instead.
    """
    try:
        if self.__mysqlConnected:
            # Get a cursor
            cur = self.__mysqlConnection.cursor()
            try:
                sqlstring = ('select distinct ndate from TB_{0}MI where ndate < '
                             'cast(\'{1}\' as date) order by ndate desc limit {2},1'
                             ).format(symbol, startDate, decreaseDays - 1)
                self.writeLog(sqlstring)

                count = cur.execute(sqlstring)
                if count > 0:
                    result = cur.fetchone()
                    return result[0]

                self.writeLog(u'MysqlDB没有查询结果请检查日期')
            finally:
                cur.close()
        else:
            self.writeLog(u'MysqlDB未连接请检查')

    except MySQLdb.Error as e:
        # Exception arguments live in ``args``. Pad them so the message
        # can always be built, even when the error carries fewer than two.
        errorArgs = tuple(e.args) + ('', '')
        self.writeLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(errorArgs[0], errorArgs[1]))

    # Fallback when no trading date could be determined
    return startDate - timedelta(days=3)
#----------------------------------------------------------------------
def processLimitOrder(self):
    """
    Match every pending limit order against the current TICK.

    Orders are compared with the tick's best ask (for buys) or best bid
    (for sells) and are filled at that price when the limit allows it.
    """
    for ref, order in self.dictOrder.items():
        # A buy fills at the ask when its limit is at or above the ask
        wantsToBuy = order.direction == DIRECTION_BUY
        if wantsToBuy and order.price >= self.currentData['AskPrice1']:
            self.executeLimitOrder(ref, order, self.currentData['AskPrice1'])

        # A sell fills at the bid when its limit is at or below the bid
        wantsToSell = order.direction == DIRECTION_SELL
        if wantsToSell and order.price <= self.currentData['BidPrice1']:
            self.executeLimitOrder(ref, order, self.currentData['BidPrice1'])
#----------------------------------------------------------------------
def executeLimitOrder(self, ref, order, price):
    """
    Simulate the complete fill of a limit order at ``price``.

    A trade report and then an order report are built and pushed to the
    strategy engine, the trade is appended to ``self.listTrade``, and the
    order is removed from the pending orders.

    ref   -- order reference, the key of the order in ``self.dictOrder``
    order -- the LimitOrder being filled
    price -- fill price; also reported as the order's LimitPrice
    """
    # Trade report
    self.tradeID = self.tradeID + 1

    tradeData = {
        'InstrumentID': order.symbol,
        'OrderRef': ref,
        'TradeID': str(self.tradeID),
        'Direction': order.direction,
        'OffsetFlag': order.offset,
        'Price': price,
        'Volume': order.volume,
        'TradeTime': order.orderTime,
    }
    self.writeLog(tradeData)

    tradeEvent = Event()
    tradeEvent.dict_['data'] = tradeData
    self.strategyEngine.updateTrade(tradeEvent)

    # Order report: the whole original volume is reported as traded
    orderData = {
        'InstrumentID': order.symbol,
        'OrderRef': ref,
        'Direction': order.direction,
        'CombOffsetFlag': order.offset,
        'LimitPrice': price,
        'VolumeTotalOriginal': order.volume,
        'VolumeTraded': order.volume,
        'InsertTime': order.orderTime,
        'CancelTime': '',
        'FrontID': '',
        'SessionID': '',
        'OrderStatus': '',
    }

    orderEvent = Event()
    orderEvent.dict_['data'] = orderData
    self.strategyEngine.updateOrder(orderEvent)

    # Record the trade
    self.listTrade.append(tradeData)

    # Remove the filled order. pop() tolerates the case where a strategy
    # callback above already cancelled this order reference.
    self.dictOrder.pop(ref, None)
#----------------------------------------------------------------------
def startBacktesting(self):
    """Replay the loaded history through the strategy engine.

    For every TICK in ``self.listDataHistory`` the pending limit orders
    are matched first, then the tick is pushed to the strategy engine.
    Afterwards the trades are saved locally and to MySQL, the elapsed
    time is logged, and the strategy engine saves its own data under this
    run's Id.
    """
    t1 = datetime.now()
    self.writeLog(u'开始回测,{0}'.format(str(t1)))

    for data in self.listDataHistory:
        # Remember the latest TICK
        self.currentData = data

        # Match pending limit orders against it
        self.processLimitOrder()

        # Push it to the strategy engine
        event = Event()
        event.dict_['data'] = data
        self.strategyEngine.updateMarketData(event)

    # Save the trades to the local result file
    self.saveTradeData()

    # Save the trades to the database
    self.saveTradeDataToMysql()

    t2 = datetime.now()
    # total_seconds() includes whole days; ``.seconds`` alone would wrap
    # around for runs longer than 24 hours.
    elapsedSeconds = int((t2 - t1).total_seconds())
    self.writeLog(u'回测结束,{0},耗时:{1}'.format(str(t2), elapsedSeconds))

    # Save the strategy's process data to the database
    self.strategyEngine.saveData(self.Id)
#----------------------------------------------------------------------
def sendOrder(self, instrumentid, exchangeid, price, pricetype, volume, direction, offset, orderTime=None):
    """Place a simulated limit order and return its order reference.

    ``exchangeid`` and ``pricetype`` are accepted but not used by the
    backtest. ``orderTime`` defaults to the moment of the call. The order
    is stored in ``self.dictOrder`` under the returned reference, which
    is the incremented order counter as a string.
    """
    # Resolve the default per call. A ``datetime.now()`` default in the
    # signature is evaluated only once, when the function is defined, and
    # would stamp every order with that same instant.
    if orderTime is None:
        orderTime = datetime.now()

    order = LimitOrder(instrumentid)    # limit order
    order.price = price                 # limit price
    order.direction = direction         # buy or sell
    order.volume = volume               # quantity
    order.offset = offset               # open or close
    order.orderTime = orderTime         # order time

    self.orderRef = self.orderRef + 1   # next order reference
    ref = str(self.orderRef)
    self.dictOrder[ref] = order

    return ref
#----------------------------------------------------------------------
def cancelOrder(self, instrumentid, exchangeid, orderref, frontid, sessionid):
    """Drop the pending order ``orderref``; unknown references are ignored.

    Only ``orderref`` is used; the other parameters are accepted and
    ignored.
    """
    self.dictOrder.pop(orderref, None)
#----------------------------------------------------------------------
def writeLog(self, log):
    """Print the log entry ``log`` to standard output."""
    print(log)
#----------------------------------------------------------------------
def subscribe(self, symbol, exchange):
    """Simulated contract subscription; does nothing in a backtest."""
    return None
#----------------------------------------------------------------------
def selectInstrument(self, symbol):
    """Return contract data for ``symbol``.

    The result is always a new dict whose only entry is ``ExchangeID``,
    set to ``'BackTesting'``; ``symbol`` is not consulted.
    """
    return {'ExchangeID': 'BackTesting'}
#----------------------------------------------------------------------
def saveTradeData(self):
    """Pickle the trade list to ``result/<Id>_Trade.pickle``.

    The ``result`` folder is created under the current working directory
    when it does not exist; an existing file for the same run Id is
    overwritten.
    """
    resultPath = os.path.join(os.getcwd(), 'result')
    if not os.path.isdir(resultPath):
        os.mkdir(resultPath)

    resultFile = os.path.join(resultPath, u'{0}_Trade.pickle'.format(self.Id))

    # Text mode is kept so the file format stays as before; that is only
    # valid with cPickle's default ASCII protocol. The with-block closes
    # the handle even when pickling fails.
    with open(resultFile, mode='w') as cache:
        cPickle.dump(self.listTrade, cache)
#----------------------------------------------------------------------
def saveTradeDataToMysql(self):
    """Save the trade list to MySQL table ``BackTest.TB_Trade``.

    Added by Incense Lee. A connection is opened first; when that fails
    the trades are saved to the local result file instead. With a
    connection, the number of trades is logged and, if there are any, all
    of them are inserted in one batch and committed. A failed insert is
    logged and not raised.

    The ``amount`` column is price times volume, negated when the trade's
    OffsetFlag is ``'0'``.
    """
    self.connectMysql()

    if not self.__mysqlConnected:
        # No database available: fall back to the local result file
        self.saveTradeData()
        return

    self.writeLog(u'{0}条交易记录.'.format(len(self.listTrade)))
    if not self.listTrade:
        return

    # One placeholder per column. The driver quotes and escapes every
    # value, which string-built SQL did not do.
    sql = ('insert into BackTest.TB_Trade '
           '(Id,symbol,orderRef,tradeID,direction,offset,price,volume,tradeTime,amount) '
           'values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)')

    rows = []
    for tradeItem in self.listTrade:
        amount = float(tradeItem['Price']) * int(tradeItem['Volume'])
        if tradeItem['OffsetFlag'] == '0':
            amount = 0 - amount

        rows.append((
            self.Id,
            tradeItem['InstrumentID'],
            tradeItem['OrderRef'],
            tradeItem['TradeID'],
            tradeItem['Direction'],
            tradeItem['OffsetFlag'],
            tradeItem['Price'],
            tradeItem['Volume'],
            tradeItem['TradeTime'].strftime('%Y-%m-%d %H:%M:%S'),
            amount,
        ))

    cur = self.__mysqlConnection.cursor()
    try:
        cur.executemany(sql, rows)
        self.__mysqlConnection.commit()
    except Exception as e:
        # Best effort, as before: report the problem and carry on
        self.writeLog(e)
    finally:
        cur.close()