diff --git a/vnpy/trader/app/ctaStrategy/arbTemplate.py b/vnpy/trader/app/ctaStrategy/arbTemplate.py index 2cf3a154..9ebf2d7f 100644 --- a/vnpy/trader/app/ctaStrategy/arbTemplate.py +++ b/vnpy/trader/app/ctaStrategy/arbTemplate.py @@ -26,7 +26,8 @@ class arbTemplate(object): base_symbol = EMPTY_STRING # 交易主货币 btc quote_symbol = EMPTY_STRING # 基准货币 usdt spot_symbol = EMPTY_STRING # 现货交易所币对 btc_usdt - future_symbol = EMPTY_STRING # 期货合约 btc:next_week:10 + future_symbol = EMPTY_STRING # 期货合约 btc:next_week:10 + future_net_symbol = EMPTY_STRING # 合约账号内的期货合约净仓symbol btc.[future] Leg1Symbol = EMPTY_STRING # 带交易所信息的symbol,如: btc_usdt.OKEX Leg2Symbol = EMPTY_STRING # 带交易所信息的symbol,如: btc:next_week:10.OKEX exchange = EMPTY_STRING @@ -46,6 +47,7 @@ class arbTemplate(object): 'quote_symbol', 'spot_symbol', 'future_symbol', + 'future_net_symbol', 'Leg1Symbol', 'Leg2Symbol', 'exchange', diff --git a/vnpy/trader/app/ctaStrategy/ctaBacktesting.py b/vnpy/trader/app/ctaStrategy/ctaBacktesting.py index 6e6bf6b1..27ebd103 100644 --- a/vnpy/trader/app/ctaStrategy/ctaBacktesting.py +++ b/vnpy/trader/app/ctaStrategy/ctaBacktesting.py @@ -24,6 +24,8 @@ import copy import pandas as pd import re import traceback +import decimal +import numpy as np from vnpy.trader.app.ctaStrategy.ctaBase import * from vnpy.trader.vtConstant import * @@ -31,6 +33,8 @@ from vnpy.trader.vtGateway import VtOrderData, VtTradeData from vnpy.trader.vtFunction import loadMongoSetting from vnpy.trader.vtEvent import * from vnpy.trader.setup_logger import setup_logger +from vnpy.trader.data_source import DataSource +from vnpy.trader.app.ctaStrategy.ctaEngine import PositionBuffer ######################################################################## class BacktestingEngine(object): @@ -177,6 +181,8 @@ class BacktestingEngine(object): self.logger = None + self.useBreakoutMode = False + def getAccountInfo(self): """返回账号的实时权益,可用资金,仓位比例,投资仓位比例上限""" if self.netCapital == EMPTY_FLOAT: @@ -343,6 +349,119 @@ class BacktestingEngine(object): # 保存本地cache文件 self.__saveDataHistoryToLocalCache(symbol, startDate, endDate) + def runBackTestingWithMongoDBTicks(self, symbol): + """ + 根据测试的每一天,从MongoDB载入历史数据,并推送Tick至回测函数 + """ + + self.capital = self.initCapital # 更新设置期初资金 + + if not self.dataStartDate: + self.writeCtaLog(u'回测开始日期未设置。') + return + + if not self.dataEndDate: + self.dataEndDate = datetime.today() + + if len(self.symbol) < 1: + self.writeCtaLog(u'回测对象未设置。') + return + + # 首先根据回测模式,确认要使用的数据类 + if self.mode == self.BAR_MODE: + self.writeCtaLog(u'本回测仅支持tick模式') + return + else: + dataClass = CtaTickData + func = self.newTick + + self.output(u'开始回测') + + # self.strategy.inited = True + self.strategy.onInit() + self.output(u'策略初始化完成') + + self.strategy.trading = True + self.strategy.onStart() + self.output(u'策略启动完成') + + # isOffline = False # WJ + isOffline = True + host, port, log = loadMongoSetting() + + self.dbClient = pymongo.MongoClient(host, port) + symbol = self.strategy.shortSymbol + self.symbol[-2:] + self.strategy.vtSymbol = symbol + collection = self.dbClient[self.dbName][symbol] + + self.output(u'开始载入数据') + + # 载入回测数据 + if not self.dataEndDate: + self.dataEndDate = datetime.now() + + testdays = (self.dataEndDate - self.dataStartDate).days + + if testdays < 1: + self.writeCtaLog(u'回测时间不足') + return + + # 循环每一天 + for i in range(0, testdays): + testday = self.dataStartDate + timedelta(days=i) + + # 看本地缓存是否存在 + cachefilename = u'{0}_{1}_{2}'.format(self.symbol, symbol, testday.strftime('%Y%m%d')) + rawTicks = self.__loadTicksFromLocalCache(cachefilename) + + dt = None + + if len(rawTicks) < 1 and isOffline == False: + + testday_monrning = testday # testday.replace(hour=0, minute=0, second=0, microsecond=0) + testday_midnight = testday + timedelta( + days=1) # testday.replace(hour=23, minute=59, second=59, microsecond=999999) + + query_time = datetime.now() + # 载入初始化需要用的数据 + flt = {'tradingDay': testday.strftime('%Y%m%d')} # WJ: using TradingDay instead of calandar day + # flt = {'datetime': {'$gte': testday_monrning, '$lt': testday_midnight}} + initCursor = collection.find(flt).sort('datetime', pymongo.ASCENDING) + + process_time = datetime.now() + # 将数据从查询指针中读取出,并生成列表 + count_ticks = 0 + + for d in initCursor: + data = dataClass() + data.__dict__ = d + rawTicks.append(data) + count_ticks += 1 + + self.writeCtaLog(u'回测日期{0},数据量:{1},查询耗时:{2},回测耗时:{3}' + .format(testday.strftime('%Y-%m-%d'), count_ticks, str(datetime.now() - query_time), + str(datetime.now() - process_time))) + + # 保存本地cache文件 + if count_ticks > 0: + self.__saveTicksToLocalCache(cachefilename, rawTicks) + + for t in rawTicks: + # 排除涨停/跌停的数据 + if ((t.askPrice1 == float('1.79769E308') or t.askPrice1 == 0) and t.askVolume1 == 0) \ + or ((t.bidPrice1 == float('1.79769E308') or t.bidPrice1 == 0) and t.bidVolume1 == 0): + continue + + # 推送到策略中 + self.newTick(t) + + # 保存最后一个Tick,确保savingDailyData()工作正常 + self.last_leg1_tick = t + self.last_leg1_tick.vtSymbol = symbol + + # 记录每日净值 + if len(rawTicks) > 1: + self.savingDailyData(testday, self.capital, self.maxCapital, self.totalCommission) def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate): """看本地缓存是否存在 @@ -531,6 +650,63 @@ class BacktestingEngine(object): return tick + def __barToTick(self, bar): + """ + 数据库查询返回的bar结构,转换为tick对象 + added by Wenjian Du """ + + # TODO + tick = CtaTickData() + tick.symbol = bar.symbol + + # 创建TICK数据对象并更新数据 + tick.vtSymbol = bar.symbol + # tick.openPrice = data['OpenPrice'] + # tick.highPrice = data['HighestPrice'] + # tick.lowPrice = data['LowestPrice'] + tick.lastPrice = float(bar.close) + + # bug fix: + # ctp日常传送的volume数据,是交易日日内累加值。数据库的volume,是数据商自行计算整理的 + # 因此,改为使用DayVolume,与CTP实盘一致 + tick.volume = bar.volume + tick.openInterest = bar.openInterest + + # tick.upperLimit = data['UpperLimitPrice'] + # tick.lowerLimit = data['LowerLimitPrice'] + + tick.datetime = bar.datetime + timedelta(seconds=self.barTimeInterval) + tick.date = tick.datetime.strftime('%Y-%m-%d') + tick.time = tick.datetime.strftime('%H:%M:%S') + # 数据库中并没有tradingDay的数据,回测时,暂时按照date授予。 + tick.tradingDay = bar.tradingDay + + tick.bidPrice1 = float(bar.close) + # tick.bidPrice2 = data['BidPrice2'] + # tick.bidPrice3 = data['BidPrice3'] + # tick.bidPrice4 = data['BidPrice4'] + # tick.bidPrice5 = data['BidPrice5'] + + tick.askPrice1 = float(bar.close) + # tick.askPrice2 = data['AskPrice2'] + # tick.askPrice3 = data['AskPrice3'] + # tick.askPrice4 = data['AskPrice4'] + # tick.askPrice5 = data['AskPrice5'] + + tick.bidVolume1 = bar.volume + # tick.bidVolume2 = data['BidVolume2'] + # tick.bidVolume3 = data['BidVolume3'] + # tick.bidVolume4 = data['BidVolume4'] + # tick.bidVolume5 = data['BidVolume5'] + + tick.askVolume1 = bar.volume + # tick.askVolume2 = data['AskVolume2'] + # tick.askVolume3 = data['AskVolume3'] + # tick.askVolume4 = data['AskVolume4'] + # tick.askVolume5 = data['AskVolume5'] + + return tick + #---------------------------------------------------------------------- def getMysqlDeltaDate(self,symbol, startDate, decreaseDays): """从mysql库中获取交易日前若干天 @@ -837,6 +1013,204 @@ class BacktestingEngine(object): # ---------------------------------------------------------------------- + def runBackTestingWithTickFile(self, mainPath, symbol): + """运行Tick回测(使用本地tick TXT csv数据) + 参数:代码 rb1610 + added by WenjianDu + 原始的tick,分别存放在白天目录1和夜盘目录2中,每天都有各个合约的数据 + Z:\ticks\SHFE\201606\RB\0601\ + RB1610.txt + RB1701.txt + .... + Z:\ticks\SHFE_night\201606\RB\0601 + RB1610.txt + RB1701.txt + .... + + 夜盘目录为自然日,不是交易日。 + + 按照回测的开始日期,到结束日期,循环每一天。 + 每天优先读取日盘数据,再读取夜盘数据。 + 读取tick(如RB1610),灌输到策略的onTick中。 + """ + self.capital = self.initCapital # 更新设置期初资金 + + if len(symbol) < 1: + self.writeCtaLog(u'合约为空') + return + + # 获得tick + self.writeCtaLog(u'arbSymbol:{0}'.format(symbol)) + + if not self.dataStartDate: + self.writeCtaLog(u'回测开始日期未设置。') + return + # RB + if len(self.symbol) < 1: + self.writeCtaLog(u'回测对象未设置。') + return + + if not self.dataEndDate: + self.dataEndDate = datetime.today() + + # 首先根据回测模式,确认要使用的数据类 + if self.mode == self.BAR_MODE: + self.writeCtaLog(u'本回测仅支持tick模式') + return + + testdays = (self.dataEndDate - self.dataStartDate).days + + if testdays < 1: + self.writeCtaLog(u'回测时间不足') + return + + for i in range(0, testdays): + testday = self.dataStartDate + timedelta(days=i) + self.output(u'回测日期:{0}'.format(testday)) + # 白天数据 + self.__loadTxtTicks(mainPath, testday, symbol) + # 撤销所有之前的orders + if self.symbol: + self.cancelOrders(self.symbol) + # # 夜盘数据 + # self.__loadTxtTicks(mainPath + '_night', testday, symbol) + self.savingDailyData(testday, self.capital, self.maxCapital, self.totalCommission) + + def __loadTxtTicks(self, mainPath, testday, symbol): + + self.writeCtaLog(u'加载回测日期:{0}\{1}的tick'.format(mainPath, testday)) + + cachefilename = u'{0}_{1}_{2}_{3}'.format(self.symbol, symbol, mainPath, testday.strftime('%Y%m%d')) + + rawTicks = self.__loadTicksFromLocalCache(cachefilename) + + dt = None + + if len(rawTicks) < 1: + + # rawFile = u'F:\\FutureData\\{0}\\{1}\\{2}\\{3}\\{4}.txt' \ + # .format(mainPath, testday.strftime('%Y%m'), self.symbol, testday.strftime('%m%d'), symbol) + rawFile = u'/home/wenjiand/Downloads/FutureData/{0}/{1}/{2}/{3}/{4}.txt' \ + .format(mainPath, testday.strftime('%Y%m'), self.strategy.shortSymbol, testday.strftime('%m%d'), self.strategy.symbol.upper()) + if not os.path.isfile(rawFile): + self.writeCtaLog(u'{0}文件不存在'.format(rawFile)) + return + + # 先读取raw的数据到目录,以日期时间为key + tempTicks = {} + + rawCsvReadFile = open(rawFile, 'r', encoding='utf8') + # reader = csv.DictReader((line.replace('\0',' ') for line in rawCsvReadFile), delimiter=",") + reader = csv.DictReader(rawCsvReadFile, delimiter=",") + self.writeCtaLog(u'加载{0}'.format(rawFile)) + for row in reader: + tick = CtaTickData() + + tick.symbol = self.symbol + tick.vtSymbol = symbol + + tick.date = testday.strftime('%Y%m%d') + tick.tradingDay = tick.date + tick.time = row['Time'] + + try: + tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y%m%d %H:%M:%S.%f') + except Exception as ex: + self.writeCtaError(u'日期转换错误:{0},{1}:{2}'.format(tick.date + ' ' + tick.time, Exception, ex)) + continue + + # 修正毫秒 + if tick.datetime.replace(microsecond=0) == dt: + # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 + tick.datetime = tick.datetime.replace(microsecond=500) + tick.time = tick.datetime.strftime('%H:%M:%S.%f') + + else: + tick.datetime = tick.datetime.replace(microsecond=0) + tick.time = tick.datetime.strftime('%H:%M:%S.%f') + + dt = tick.datetime + + tick.lastPrice = float(row['LastPrice']) + tick.volume = int(float(row['LVolume'])) + tick.bidPrice1 = float(row['BidPrice']) # 叫买价(价格低) + tick.bidVolume1 = int(float(row['BidVolume'])) + tick.askPrice1 = float(row['AskPrice']) # 叫卖价(价格高) + tick.askVolume1 = int(float(row['AskVolume'])) + + # 排除涨停/跌停的数据 + if (tick.bidPrice1 == float('1.79769E308') and tick.bidVolume1 == 0) \ + or (tick.askPrice1 == float('1.79769E308') and tick.askVolume1 == 0): + continue + + dtStr = tick.date + ' ' + tick.time + if dtStr in tempTicks: + pass + # self.writeCtaError(u'日内数据重复,异常,数据时间为:{0}'.format(dtStr)) + else: + tempTicks[dtStr] = tick + + rawTicks.append(tick) + + del tempTicks + + # 保存到历史目录 + if len(rawTicks) > 0: + self.__saveTicksToLocalCache(cachefilename, rawTicks) + + for t in rawTicks: + # 推送到策略中 + self.newTick(t) + + # 保存最后一个Tick,确保savingDailyData()工作正常 + self.last_leg1_tick = t + self.last_leg1_tick.vtSymbol = symbol + + def __loadTicksFromLocalCache(self, filename): + """从本地缓存中,加载数据""" + # 运行路径下cache子目录 + cacheFolder = os.getcwd() + '/cache' + # cacheFolder = '/home/wenjiand/Workspaces/huafu-vnpy/vnpy/trader/app/ctaStrategy/strategy/cache' + + # cache文件 + cacheFile = u'{0}/{1}.pickle'. \ + format(cacheFolder, filename) + + if not os.path.isfile(cacheFile): + return [] + else: + # 从cache文件加载 + cache = open(cacheFile, mode='rb') + l = cPickle.load(cache) + cache.close() + return l + + def __saveTicksToLocalCache(self, filename, arbticks): + """保存价差tick到本地缓存目录""" + # 运行路径下cache子目录 + cacheFolder = os.getcwd() + '/cache' + # cacheFolder = '/home/wenjiand/Workspaces/huafu-vnpy/vnpy/trader/app/ctaStrategy/strategy/cache' + + # 创建cache子目录 + if not os.path.isdir(cacheFolder): + os.mkdir(cacheFolder) + + # cache 文件名 + cacheFile = u'{0}/{1}.pickle'. \ + format(cacheFolder, filename) + + # 重复存在 返回 + if os.path.isfile(cacheFile): + return False + + else: + # 写入cache文件 + cache = open(cacheFile, mode='wb') + cPickle.dump(arbticks, cache) + cache.close() + return True + + # ---------------------------------------------------------------------- def runBackTestingWithArbTickFile2(self, leg1MainPath,leg2MainPath, arbSymbol): """运行套利回测(使用本地tick csv数据) 参数:套利代码 SP rb1610&rb1701 @@ -1599,7 +1973,10 @@ class BacktestingEngine(object): bar.low = self.roundToPriceTick(float(row['low'])) bar.close = self.roundToPriceTick(float(row['close'])) bar.volume = float(row['volume']) if len(row['volume'])>0 else 0 - barEndTime = datetime.strptime(row['index'], '%Y-%m-%d %H:%M:%S') + if '-' in row['index']: + barEndTime = datetime.strptime(row['index'], '%Y-%m-%d %H:%M:%S') + else: + barEndTime = datetime.strptime(row['datetime'], '%Y%m%d%H%M%S') # 使用Bar的开始时间作为datetime bar.datetime = barEndTime - timedelta(seconds=self.barTimeInterval) @@ -1607,7 +1984,10 @@ class BacktestingEngine(object): bar.date = bar.datetime.strftime('%Y-%m-%d') bar.time = bar.datetime.strftime('%H:%M:%S') if 'trading_date' in row: - bar.tradingDay = row['trading_date'] + if len(row['trading_date']) is 8: + bar.tradingDay = row['trading_date'][0:4] + '-' + row['trading_date'][4:6] + '-' + row['trading_date'][6:] + else: + bar.tradingDay = row['trading_date'] else: if bar.datetime.hour >=21 and not self.is_7x24: if bar.datetime.isoweekday() == 5: @@ -1629,7 +2009,16 @@ class BacktestingEngine(object): self.maxCapital,self.totalCommission,benchmark=bar.close) last_tradingDay = bar.tradingDay - self.newBar(bar) + # Simulate latest tick and send it to Strategy + simTick = self.__barToTick(bar) + # self.tick = simTick + self.strategy.curTick = simTick + + # Check the order triggers and deliver the bar to the Strategy + if self.useBreakoutMode is False: + self.newBar(bar) + else: + self.newBarForBreakout(bar) if not self.strategy.trading and self.strategyStartDate < bar.datetime: self.strategy.trading = True @@ -1641,7 +2030,132 @@ class BacktestingEngine(object): return except Exception as ex: - self.writeCtaLog(u'{}:{}'.format(str(ex),traceback.format_exc())) + self.writeCtaError(u'回测异常导致停止') + self.writeCtaError(u'{},{}'.format(str(ex),traceback.format_exc())) + return + + #---------------------------------------------------------------------- + def runBackTestingWithDataSource(self): + """运行回测(使用本地csv数据) + added by IncenseLee + """ + self.capital = self.initCapital # 更新设置期初资金 + if not self.dataStartDate: + self.writeCtaLog(u'回测开始日期未设置。') + return + + if not self.dataEndDate: + self.dataEndDate = datetime.today() + + if len(self.symbol) < 1: + self.writeCtaLog(u'回测对象未设置。') + return + + # 首先根据回测模式,确认要使用的数据类 + if not self.mode == self.BAR_MODE: + self.writeCtaLog(u'文件仅支持bar模式,若扩展tick模式,需要修改本方法') + return + + # self.output(u'开始回测') + + # self.strategy.inited = True + self.strategy.onInit() + self.output(u'策略初始化完成') + + self.strategy.trading = True + self.strategy.onStart() + self.output(u'策略启动完成') + + self.output(u'开始载入数据') + + # 载入回测数据 + testdays = (self.dataEndDate - self.dataStartDate).days + + rawBars = [] + # 看本地缓存是否存在 + cachefilename = u'{0}_{1}_{2}'.format(self.symbol, self.dataStartDate.strftime('%Y%m%d'), + self.dataEndDate.strftime('%Y%m%d')) + rawBars = self.__loadTicksFromLocalCache(cachefilename) + + if len(rawBars) < 1: + self.writeCtaLog(u'从数据库中读取数据') + + query_time = datetime.now() + ds = DataSource() + start_date = self.dataStartDate.strftime('%Y-%m-%d') + end_date = self.dataEndDate.strftime('%Y-%m-%d') + fields = ['open', 'close', 'high', 'low', 'volume', 'open_interest', 'limit_up', 'limit_down', + 'trading_date'] + last_bar_dt = None + + df = ds.get_price(order_book_id=self.strategy.symbol, start_date=start_date, + end_date=end_date, frequency='1m', fields=fields) + + process_time = datetime.now() + # 将数据从查询指针中读取出,并生成列表 + count_bars = 0 + self.writeCtaLog(u'一共获取{}条{}分钟数据'.format(len(df), '1m')) + for idx in df.index: + row = df.loc[idx] + # self.writeCtaLog('{}: {}, o={}, h={}, l={}, c={}'.format(count_bars, datetime.strptime(str(idx), '%Y-%m-%d %H:%M:00'), + # row['open'], row['high'], row['low'], row['close'])) + bar = CtaBarData() + bar.vtSymbol = self.symbol + bar.symbol = self.symbol + last_bar_dt = datetime.strptime(str(idx), '%Y-%m-%d %H:%M:00') + bar.datetime = last_bar_dt - timedelta(minutes=1) + bar.date = bar.datetime.strftime('%Y-%m-%d') + bar.time = bar.datetime.strftime('%H:%M:00') + bar.tradingDay = datetime.strptime(str(int(row['trading_date'])), '%Y%m%d') + bar.open = float(row['open']) + bar.high = float(row['high']) + bar.low = float(row['low']) + bar.close = float(row['close']) + bar.volume = int(row['volume']) + rawBars.append(bar) + count_bars += 1 + + self.writeCtaLog(u'回测日期{}-{},数据量:{},查询耗时:{},回测耗时:{}' + .format(self.dataStartDate.strftime('%Y-%m-%d'), self.dataEndDate.strftime('%Y%m%d'), + count_bars, str(datetime.now() - query_time), str(datetime.now() - process_time))) + + # 保存本地cache文件 + if count_bars > 0: + self.__saveTicksToLocalCache(cachefilename, rawBars) + + if len(rawBars) < 1: + self.writeCtaLog(u'ERROR 拿不到指定日期的数据,结束') + return + + self.output(u'开始回放数据') + last_tradingDay = 0 + for bar in rawBars: + # self.writeCtaLog(u'{} o:{};h:{};l:{};c:{},v:{},tradingDay:{},H2_count:{}' + # .format(bar.date+' '+bar.time, bar.open, bar.high, + # bar.low, bar.close, bar.volume, bar.tradingDay, self.lineH2.m1_bars_count)) + + # if not (bar.datetime < self.dataStartDate or bar.datetime >= self.dataEndDate): + if True: + if last_tradingDay == 0: + last_tradingDay = bar.tradingDay + elif last_tradingDay != bar.tradingDay: + if last_tradingDay is not None: + self.savingDailyData(last_tradingDay, self.capital, self.maxCapital, self.totalCommission) + last_tradingDay = bar.tradingDay + + # Simulate latest tick and send it to Strategy + simTick = self.__barToTick(bar) + # self.tick = simTick + self.strategy.curTick = simTick + + # Check the order triggers and deliver the bar to the Strategy + if self.useBreakoutMode is False: + self.newBar(bar) + else: + self.newBarForBreakout(bar) + + if self.netCapital < 0: + self.writeCtaError(u'净值低于0,回测停止') return #---------------------------------------------------------------------- @@ -1828,6 +2342,17 @@ class BacktestingEngine(object): self.__sendOnBarEvent(bar) # 推送K线到事件 self.last_bar = bar + # ---------------------------------------------------------------------- + def newBarForBreakout(self, bar): + """新的K线""" + self.bar = bar + self.dt = bar.datetime + self.strategy.onBar(bar) # 推送K线到策略中 + self.crossLimitOrder() # 先撮合限价单 + self.crossStopOrder() # 再撮合停止单 + self.__sendOnBarEvent(bar) # 推送K线到事件 + self.last_bar = bar + # ---------------------------------------------------------------------- def newTick(self, tick): """新的Tick""" @@ -1848,14 +2373,14 @@ class BacktestingEngine(object): self.strategy.name = self.strategy.className self.strategy.onInit() - #self.strategy.onStart() + self.strategy.onStart() # --------------------------------------------------------------------- def saveStrategyData(self): """保存策略数据""" if self.strategy is None: return - + self.writeCtaLog(u'save strategy data') self.strategy.saveData() #---------------------------------------------------------------------- @@ -1943,6 +2468,10 @@ class BacktestingEngine(object): so.stopOrderID = stopOrderID so.status = STOPORDER_WAITING + # added by IncenseLee + so.gatewayName = STOPORDERPREFIX[0:-1] + so.orderId = str(self.stopOrderCount) + if orderType == CTAORDER_BUY: so.direction = DIRECTION_LONG so.offset = OFFSET_OPEN @@ -1981,18 +2510,21 @@ class BacktestingEngine(object): buyBestCrossPrice = self.roundToPriceTick(self.bar.open) + self.priceTick # 在当前时间点前发出的买入委托可能的最优成交价 sellBestCrossPrice = self.roundToPriceTick(self.bar.open) - self.priceTick # 在当前时间点前发出的卖出委托可能的最优成交价 vtSymbol = self.bar.vtSymbol + symbol = self.bar.symbol else: buyCrossPrice = self.tick.askPrice1 sellCrossPrice = self.tick.bidPrice1 buyBestCrossPrice = self.tick.askPrice1 sellBestCrossPrice = self.tick.bidPrice1 vtSymbol = self.tick.vtSymbol - + symbol = self.tick.symbol + # 遍历限价单字典中的所有限价单 - for orderID, order in list(self.workingLimitOrderDict.items()): + workingLimitOrderDictClone = copy.deepcopy(self.workingLimitOrderDict) + for orderID, order in list(workingLimitOrderDictClone.items()): # 判断是否会成交 - buyCross = order.direction == DIRECTION_LONG and order.price >= buyCrossPrice and vtSymbol.lower() == order.vtSymbol.lower() - sellCross = order.direction == DIRECTION_SHORT and order.price <= sellCrossPrice and vtSymbol.lower() == order.vtSymbol.lower() + buyCross = order.direction == DIRECTION_LONG and order.price >= buyCrossPrice and (vtSymbol.lower() == order.vtSymbol.lower() or symbol.lower() == order.vtSymbol.lower()) + sellCross = order.direction == DIRECTION_SHORT and order.price <= sellCrossPrice and (vtSymbol.lower() == order.vtSymbol.lower() or symbol.lower() == order.vtSymbol.lower()) # 如果发生了成交 if buyCross or sellCross: @@ -2014,10 +2546,16 @@ class BacktestingEngine(object): # 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105 # 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100 if buyCross: - trade.price = min(order.price, buyBestCrossPrice) + if self.useBreakoutMode is False: + trade.price = min(order.price, buyBestCrossPrice) + else: + trade.price = max(order.price, buyBestCrossPrice) self.strategy.pos += order.totalVolume else: - trade.price = max(order.price, sellBestCrossPrice) + if self.useBreakoutMode is False: + trade.price = max(order.price, sellBestCrossPrice) + else: + trade.price = min(order.price, sellBestCrossPrice) self.strategy.pos -= order.totalVolume trade.volume = order.totalVolume @@ -2027,7 +2565,16 @@ class BacktestingEngine(object): self.tradeDict[tradeID] = trade self.writeCtaLog(u'TradeId:{0}'.format(tradeID)) - + + # 更新持仓缓存数据 # TODO: do we need this? + posBuffer = self.posBufferDict.get(trade.vtSymbol, None) + if not posBuffer: + posBuffer = PositionBuffer() + posBuffer.vtSymbol = trade.vtSymbol + self.posBufferDict[trade.vtSymbol] = posBuffer + posBuffer.updateTradeData(trade) + self.writeCtaLog(u'DEBUG-- [ctaBacktesting] crossLimitOrder: TradeId:{}, posBuffer = {}'.format(tradeID, posBuffer.toStr())) + # 推送委托数据 order.tradedVolume = order.totalVolume order.status = STATUS_ALLTRADED @@ -2053,17 +2600,20 @@ class BacktestingEngine(object): sellCrossPrice = self.bar.low # 若卖出方向限价单价格高于该价格,则会成交 bestCrossPrice = self.bar.open # 最优成交价,买入停止单不能低于,卖出停止单不能高于 vtSymbol = self.bar.vtSymbol + symbol = self.bar.symbol else: buyCrossPrice = self.tick.lastPrice sellCrossPrice = self.tick.lastPrice bestCrossPrice = self.tick.lastPrice vtSymbol = self.tick.vtSymbol - + symbol = self.tick.symbol + # 遍历停止单字典中的所有停止单 - for stopOrderID, so in self.workingStopOrderDict.items(): + workingStopOrderDictClone = copy.deepcopy(self.workingStopOrderDict) + for stopOrderID, so in workingStopOrderDictClone.items(): # 判断是否会成交 - buyCross = so.direction == DIRECTION_LONG and so.price <= buyCrossPrice and vtSymbol.lower() == so.vtSymbol.lower() - sellCross = so.direction == DIRECTION_SHORT and so.price >= sellCrossPrice and vtSymbol.lower() == so.vtSymbol.lower() + buyCross = so.direction == DIRECTION_LONG and so.price <= buyCrossPrice and (vtSymbol.lower() == so.vtSymbol.lower() or symbol.lower() == order.vtSymbol.lower()) + sellCross = so.direction == DIRECTION_SHORT and so.price >= sellCrossPrice and (vtSymbol.lower() == so.vtSymbol.lower() or symbol.lower() == order.vtSymbol.lower()) # 如果发生了成交 if buyCross or sellCross: @@ -2096,6 +2646,14 @@ class BacktestingEngine(object): self.tradeDict[tradeID] = trade + # 更新持仓缓存数据 # TODO: do we need this? + posBuffer = self.posBufferDict.get(trade.vtSymbol, None) + if not posBuffer: + posBuffer = PositionBuffer() + posBuffer.vtSymbol = trade.vtSymbol + self.posBufferDict[trade.vtSymbol] = posBuffer + posBuffer.updateTradeData(trade) + # 推送委托数据 so.status = STOPORDER_TRIGGERED @@ -2111,13 +2669,16 @@ class BacktestingEngine(object): order.tradedVolume = so.volume order.status = STATUS_ALLTRADED order.orderTime = trade.tradeTime + order.gatewayName = so.gatewayName self.strategy.onOrder(order) self.limitOrderDict[orderID] = order # 从字典中删除该限价单 - if stopOrderID in self.workingStopOrderDict: + try: del self.workingStopOrderDict[stopOrderID] + except Exception as ex: + self.writeCtaError(u'crossStopOrder exception:{},{}'.format(str(ex), traceback.format_exc())) # 若采用实时计算净值 if self.calculateMode == self.REALTIME_MODE: @@ -2185,12 +2746,18 @@ class BacktestingEngine(object): def writeCtaError(self, content,strategy_name=None): """记录异常""" self.output(u'Error:{}'.format(content)) - self.writeCtaLog(content) + if self.logger: + self.logger.error(content) + else: + self.createLogger() def writeCtaWarning(self, content,strategy_name=None): """记录告警""" self.output(u'Warning:{}'.format(content)) - self.writeCtaLog(content) + if self.logger: + self.logger.warning(content) + else: + self.createLogger() def writeCtaNotification(self,content,strategy_name=None): """记录通知""" @@ -2245,12 +2812,14 @@ class BacktestingEngine(object): gr = None # 组合的交易结果 coverVolume = trade.volume - + self.writeCtaLog(u'平空:{}'.format(coverVolume)) while coverVolume > 0: if len(self.shortPosition) == 0: self.writeCtaError(u'异常!没有开空仓的数据') raise Exception(u'realtimeCalculate2() Exception,没有开空仓的数据') return + cur_short_pos_list = [s_pos.volume for s_pos in self.shortPosition] + self.writeCtaLog(u'当前空单:{}'.format(cur_short_pos_list)) pop_indexs = [i for i, val in enumerate(self.shortPosition) if val.vtSymbol == trade.vtSymbol] if len(pop_indexs) < 1: self.writeCtaError(u'异常,没有对应symbol:{0}的空单持仓'.format(trade.vtSymbol)) @@ -2263,9 +2832,10 @@ class BacktestingEngine(object): # 开空volume,不大于平仓volume if coverVolume >= entryTrade.volume: - self.writeCtaLog(u'coverVolume:{0} >= entryTrade.volume:{1}'.format(coverVolume, entryTrade.volume)) + self.writeCtaLog(u'开空volume,不大于平仓volume, coverVolume:{} ,先平::{}'.format(coverVolume, entryTrade.volume)) coverVolume = coverVolume - entryTrade.volume - + if coverVolume>0: + self.writeCtaLog(u'剩余待平数量:{}'.format(coverVolume)) self.output(u'{0}空平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) self.writeCtaLog(u'{0}空平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) @@ -2293,7 +2863,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏:{9},手续费:{10}'\ + msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏pnl={9},手续费:{10}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, shortid, entryTrade.price, trade.tradeTime, tradeid, trade.price, entryTrade.volume, result.pnl,result.commission) @@ -2305,9 +2875,9 @@ class BacktestingEngine(object): if coverVolume > 0: # 属于组合 gr = copy.deepcopy(result) - else: # 删除平空交易单, + self.writeCtaLog(u'删除平空交易单,tradeID:'.format(trade.tradeID)) del self.tradeDict[trade.tradeID] else: @@ -2319,9 +2889,11 @@ class BacktestingEngine(object): # 所有仓位平完 if coverVolume == 0: + self.writeCtaLog(u'所有平空仓位撮合完毕') gr.volume = abs(trade.volume) #resultDict[entryTrade.dt] = gr # 删除平空交易单, + self.writeCtaLog(u'删除平空交易单:{}'.format(trade.tradeID)) del self.tradeDict[trade.tradeID] # 开空volume,大于平仓volume,需要更新减少tradeDict的数量。 @@ -2353,7 +2925,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏:{9},手续费:{10}'\ + msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏pnl={9},手续费:{10}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, shortid, entryTrade.price, trade.tradeTime, tradeid, trade.price, coverVolume, result.pnl,result.commission) @@ -2362,7 +2934,10 @@ class BacktestingEngine(object): # 更新(减少)开仓单的volume,重新推进开仓单列表中 entryTrade.volume = shortVolume + self.writeCtaLog(u'更新(减少)开仓单的volume,重新推进开仓单列表中:{}'.format(entryTrade.volume)) self.shortPosition.append(entryTrade) + cur_short_pos_list = [s_pos.volume for s_pos in self.shortPosition] + self.writeCtaLog(u'当前空单:{}'.format(cur_short_pos_list)) coverVolume = 0 resultDict.append(result) @@ -2443,7 +3018,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - msg = u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏:{9},手续费:{10}'\ + msg = u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏pnl={9},手续费:{10}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, longid, entryTrade.price, trade.tradeTime, tradeid, trade.price, @@ -2503,7 +3078,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - msg = u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏:{9},手续费:{10}'\ + msg = u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏pnl={9},手续费:{10}'\ .format(gId, entryTrade.vtSymbol,entryTrade.tradeTime, longid, entryTrade.price, trade.tradeTime, tradeid, trade.price, sellVolume, result.pnl, result.commission) self.output(msg) @@ -2625,6 +3200,8 @@ class BacktestingEngine(object): dict['maxCapital'] = m long_list = [] today_margin = 0 + today_margin_long = 0 + today_margin_short = 0 long_pos_occupy_money = 0 short_pos_occupy_money = 0 @@ -2636,15 +3213,16 @@ class BacktestingEngine(object): else: benchmark = 1 + positionMsg = "" for longpos in self.longPosition: symbol = '-' if longpos.vtSymbol == EMPTY_STRING else longpos.vtSymbol # 计算持仓浮盈浮亏/占用保证金 pos_margin = 0 - if self.last_leg1_tick is not None and self.last_leg1_tick.vtSymbol == symbol: + if self.last_leg1_tick is not None and (self.last_leg1_tick.vtSymbol == symbol or self.last_leg1_tick.symbol == symbol): pos_margin = (self.last_leg1_tick.lastPrice - longpos.price) * longpos.volume * self.size long_pos_occupy_money += self.last_leg1_tick.lastPrice * abs(longpos.volume) * self.size * self.margin_rate - elif self.last_leg2_tick is not None and self.last_leg2_tick.vtSymbol == symbol: + elif self.last_leg2_tick is not None and (self.last_leg2_tick.vtSymbol == symbol or self.last_leg2_tick.symbol == symbol): pos_margin = (self.last_leg2_tick.lastPrice - longpos.price) * longpos.volume * self.size long_pos_occupy_money += self.last_leg2_tick.lastPrice * abs(longpos.volume) * self.size * self.margin_rate @@ -2654,18 +3232,20 @@ class BacktestingEngine(object): longpos.volume) * self.size * self.margin_rate today_margin += pos_margin + today_margin_long += pos_margin long_list.append({'symbol': symbol, 'direction':'long','price':longpos.price,'volume':longpos.volume,'margin':pos_margin}) + positionMsg += "{},long,p={},v={},m={};".format(symbol,longpos.price,longpos.volume,pos_margin) short_list = [] for shortpos in self.shortPosition: symbol = '-' if shortpos.vtSymbol == EMPTY_STRING else shortpos.vtSymbol # 计算持仓浮盈浮亏/占用保证金 pos_margin = 0 - if self.last_leg1_tick is not None and self.last_leg1_tick.vtSymbol == symbol: + if self.last_leg1_tick is not None and (self.last_leg1_tick.vtSymbol == symbol or self.last_leg1_tick.symbol == symbol): pos_margin = (shortpos.price - self.last_leg1_tick.lastPrice) * shortpos.volume * self.size short_pos_occupy_money += self.last_leg1_tick.lastPrice * abs(shortpos.volume) * self.size * self.margin_rate - elif self.last_leg2_tick is not None and self.last_leg2_tick.vtSymbol == symbol: + elif self.last_leg2_tick is not None and (self.last_leg2_tick.vtSymbol == symbol or self.last_leg2_tick.symbol == symbol): pos_margin = (shortpos.price - self.last_leg2_tick.lastPrice) * shortpos.volume * self.size short_pos_occupy_money += self.last_leg2_tick.lastPrice * abs( shortpos.volume) * self.size * self.margin_rate elif self.last_bar is not None: @@ -2673,8 +3253,10 @@ class BacktestingEngine(object): short_pos_occupy_money += self.last_bar.close * abs( shortpos.volume) * self.size * self.margin_rate today_margin += pos_margin + today_margin_short += pos_margin short_list.append({'symbol': symbol, 'direction': 'short', 'price': shortpos.price, 'volume': shortpos.volume, 'margin': pos_margin}) + positionMsg += "{},short,p={},v={},m={};".format(symbol,shortpos.price,shortpos.volume,pos_margin) dict['net'] = c + today_margin dict['rate'] = (c + today_margin )/ self.initCapital @@ -2686,6 +3268,20 @@ class BacktestingEngine(object): dict['occupyRate'] = dict['occupyMoney'] / dict['capital'] dict['commission'] = commission dict['benchmark'] = benchmark + dict['todayMarginLong'] = today_margin_long + dict['todayMarginShort'] = today_margin_short + self.last_leg1_tick = None + if self.tick is not None: + dict['lastPrice'] = self.tick.lastPrice + elif self.last_leg1_tick is not None: + dict['lastPrice'] = self.last_leg1_tick.lastPrice + elif self.last_leg2_tick is not None: + dict['lastPrice'] = self.last_leg2_tick.lastPrice + elif self.last_bar is not None: + dict['lastPrice'] = self.last_bar.close + else: + dict['lastPrice'] = self.dailyList[-1]['lastPrice'] + self.dailyList.append(dict) # 更新每日浮动净值 @@ -2700,42 +3296,128 @@ class BacktestingEngine(object): self.daily_max_drawdown_rate = drawdown_rate self.max_drowdown_rate_time = dict['date'] + self.writeCtaLog(u'DEBUG---: savingDailyData, {}: lastPrice={}, net={}, capital={} max={} margin={} commission={} longPos={} shortPos={}, {}'.format( + dict['date'], dict['lastPrice'], dict['net'], c, m, today_margin, commission, len(long_list), len(short_list), positionMsg)) + # ---------------------------------------------------------------------- - def writeWenHuaSignal(self, filehandle, count, bardatetime, price, text): + def writeWenHuaSignal(self, filehandle, count, bardatetime, price, text, mask=52): """ 输出到文华信号 :param filehandle: :param count: :param bardatetime: + :param price: :param text: + :param mask: bit8~1 = [(H2)(H1)(M30)(M15)(M10)(M5)(M3)(M1)], e.g. 52 means M30, M15, M5 :return: """ # 文华信号 - barDate = bardatetime.strftime('%Y%m%d') + bardatetime2 = bardatetime + if bardatetime.hour >= 21: + if bardatetime.isoweekday() == 5: + # 星期五=》星期一 + bardatetime2 = bardatetime + timedelta(days=3) + else: + # 第二天 + bardatetime2 = bardatetime + timedelta(days=1) + elif bardatetime.hour < 8 and bardatetime.isoweekday() == 6: + # 星期六=>星期一 + bardatetime2 = bardatetime + timedelta(days=2) + barDate = bardatetime2.strftime('%Y%m%d') barTime = bardatetime.strftime('%H%M') + + isFirst = False + prefixMsg = '(AA{}'.format(count) outputMsg = 'AA{}:=DATE={};\n'.format(count, barDate[2:]) filehandle.write(outputMsg) - outputMsg = 'BB{}:=PERIOD=1&&TIME={};\n'.format(count, barDate[2:], barTime) - #filehandle.write(outputMsg) - outputMsg = 'CC{}:=PERIOD=2&&TIME>={}&&TIME<={}+3;\n'.format(count, barTime, barTime) - #filehandle.write(outputMsg) - outputMsg = 'DD{}:=PERIOD=3&&TIME>={}&&TIME<={}+5;\n'.format(count, barTime, barTime) - #filehandle.write(outputMsg) - outputMsg = 'EE{}:=PERIOD=4&&TIME>={}&&TIME<={}+10;\n'.format(count, barTime, barTime) - #filehandle.write(outputMsg) - outputMsg = 'FF{}:=PERIOD=5&&TIME>={}&&TIME<={}+30;\n'.format(count, barTime, barTime) + barTime = bardatetime.strftime('%H%M') + if mask & 1 > 0: # Min1 + outputMsg = 'BB{}:=PERIOD=1&&TIME={};\n'.format(count, barDate[2:], barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + prefixMsg += 'BB{}'.format(count) + if mask & 2 > 0: # Min3 + barTimeBegin = (bardatetime - timedelta(minutes=3)).strftime('%H%M') + outputMsg = 'CC{}:=PERIOD=2&&TIME>{}&&TIME<={};\n'.format(count, barTimeBegin, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'CC{}'.format(count) + if mask & 4 > 0: # Min5 + barTimeBegin = (bardatetime - timedelta(minutes=5)).strftime('%H%M') + outputMsg = 'DD{}:=PERIOD=3&&TIME>{}&&TIME<={};\n'.format(count, barTimeBegin, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'DD{}'.format(count) + if mask & 8 > 0: # Min10 + barTimeBegin = (bardatetime - timedelta(minutes=10)).strftime('%H%M') + outputMsg = 'EE{}:=PERIOD=4&&TIME>{}&&TIME<={};\n'.format(count, barTimeBegin, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'EE{}'.format(count) + if mask & 16 > 0: # Min15 + barTimeBegin = (bardatetime - timedelta(minutes=15)).strftime('%H%M') + outputMsg = 'FF{}:=PERIOD=5&&TIME>{}&&TIME<={};\n'.format(count, barTimeBegin, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'FF{}'.format(count) + if mask & 32 > 0: # Min30 + barTimeBegin = (bardatetime - timedelta(minutes=30)).strftime('%H%M') + if bardatetime.hour == 10: + if bardatetime.minute >30 and bardatetime.minute < 45: + barTimeBegin = (bardatetime - timedelta(minutes=45)).strftime('%H%M') + elif bardatetime.hour == 13: + if bardatetime.minute < 45: + barTimeBegin = (bardatetime - timedelta(minutes=150)).strftime('%H%M') + outputMsg = 'GG{}:=PERIOD=6&&TIME>{}&&TIME<={};\n'.format(count, barTimeBegin, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'GG{}'.format(count) + if mask & 64 > 0: # Hour1 + outputMsg = 'HH{}:=PERIOD=7&&TIME>{}-59&&TIME<={};\n'.format(count, barTime, barTime) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'HH{}'.format(count) + if mask & 128 > 0: # Hour2 + outputMsg = 'II{}:=PERIOD=8;\n'.format(count) + filehandle.write(outputMsg) + if isFirst is False: + prefixMsg += ' AND (' + isFirst = True + else: + prefixMsg += ' OR ' + prefixMsg += 'II{}'.format(count) + if isFirst is True: + prefixMsg += ')' + + outputMsg = 'DRAWICON' + prefixMsg + ', {}, \'ICO14\');\n'.format(price) filehandle.write(outputMsg) - outputMsg = 'GG{}:=PERIOD=6&&TIME>={}&&TIME<={}+60;\n'.format(count, barTime, barTime) - filehandle.write(outputMsg) - outputMsg = 'HH{}:=PERIOD=7&&TIME>={}&&TIME<={}+120;\n'.format(count, barTime, barTime) - filehandle.write(outputMsg) - outputMsg = 'II{}:=PERIOD=8;\n'.format(count) - filehandle.write(outputMsg) - outputMsg = 'DRAWICON(AA{} AND ( FF{} OR GG{} OR HH{} OR II{}), L, \'ICO14\');\n'.format( - count, count, count, count, count) - filehandle.write(outputMsg) - outputMsg = 'DRAWTEXT(AA{} AND (FF{} OR GG{} OR HH{} OR II{}), {}, \'{}\');\n'.format( - count, count, count, count, count, price, text) + outputMsg = 'DRAWTEXT' + prefixMsg + ', H, \'{}\');\n'.format(text) filehandle.write(outputMsg) filehandle.flush() @@ -2830,7 +3512,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - self.writeCtaLog(u'{9}@{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏:{5}' + self.writeCtaLog(u'{9}@{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏pnl={5}' .format(entryTrade.tradeTime, entryTrade.price, trade.tradeTime, trade.price, tradeUnit, result.pnl, i, shortid, tradeid, gId)) @@ -2899,7 +3581,7 @@ class BacktestingEngine(object): t['Commission'] = result.commission self.exportTradeList.append(t) - self.writeCtaLog(u'{9}@{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏:{5}' + self.writeCtaLog(u'{9}@{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏pnl={5}' .format(entryTrade.tradeTime, entryTrade.price, trade.tradeTime,trade.price, tradeUnit, result.pnl, i, longid, tradeid, gId)) @@ -2998,67 +3680,92 @@ class BacktestingEngine(object): for row in self.exportTradeList: writer.writerow(row) + # 交易记录生成文华对应的公式 if self.export_wenhua_signal: - wh_records = OrderedDict() + filename = os.path.abspath(os.path.join(self.get_logs_path(), + '{}_WenHua_{}.txt'.format(s, datetime.now().strftime('%Y%m%d_%H%M')))) + self.writeCtaLog(u'save trade records for WenHua:{}'.format(filename)) + wenhuaSingalCount = 0 + wenhuaSignalFile = open(filename, mode='w') + for t in self.exportTradeList: if t['Direction'] is 'Long': - k = '{}_{}_{}'.format(t['OpenTime'], 'Buy', t['OpenPrice']) # 生成文华用的指标信号 - v = {'time': datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), 'price':t['OpenPrice'], 'action': 'Buy', 'volume':t['Volume']} - r = wh_records.get(k,None) - if r is not None: - r['volume'] += t['Volume'] - else: - wh_records[k] = v - - k = '{}_{}_{}'.format(t['CloseTime'], 'Sell', t['ClosePrice']) - # 生成文华用的指标信号 - v = {'time': datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['ClosePrice'], 'action': 'Sell', 'volume': t['Volume']} - r = wh_records.get(k, None) - if r is not None: - r['volume'] += t['Volume'] - else: - wh_records[k] = v - + msg = 'Buy@{},{}'.format(t['OpenPrice'], t['Volume']) + self.writeWenHuaSignal(wenhuaSignalFile, wenhuaSingalCount, datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), t['OpenPrice'], msg) + wenhuaSingalCount += 1 + msg = 'Sell@{},{} ({})'.format(t['ClosePrice'], t['Volume'], round(t['Profit'])) + self.writeWenHuaSignal(wenhuaSignalFile, wenhuaSingalCount, datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), t['ClosePrice'], msg) + wenhuaSingalCount += 1 else: - k = '{}_{}_{}'.format(t['OpenTime'], 'Short', t['OpenPrice']) # 生成文华用的指标信号 - v = {'time': datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['OpenPrice'], 'action': 'Short', 'volume': t['Volume']} - r = wh_records.get(k, None) - if r is not None: - r['volume'] += t['Volume'] - else: - wh_records[k] = v - k = '{}_{}_{}'.format(t['CloseTime'], 'Cover', t['ClosePrice']) - # 生成文华用的指标信号 - v = {'time': datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['ClosePrice'], 'action': 'Cover', 'volume': t['Volume']} - r = wh_records.get(k, None) - if r is not None: - r['volume'] += t['Volume'] - else: - wh_records[k] = v - - branchs = 0 - count = 0 - wh_signal_file = None - for r in list(wh_records.values()): - if count % 200 == 0: - if wh_signal_file is not None: - wh_signal_file.close() - - # 交易记录生成文华对应的公式 - filename = os.path.abspath(os.path.join(self.get_logs_path(), - '{}_WenHua_{}_{}.csv'.format(s, datetime.now().strftime('%Y%m%d_%H%M'), branchs))) - branchs += 1 - self.writeCtaLog(u'save trade records for WenHua:{}'.format(filename)) - - wh_signal_file = open(filename, mode='w') - - count += 1 - if wh_signal_file is not None: - self.writeWenHuaSignal(filehandle=wh_signal_file, count=count, bardatetime=r['time'],price=r['price'], text='{}({})'.format(r['action'],r['volume'])) - if wh_signal_file is not None: - wh_signal_file.close() + msg = 'Short@{},{}'.format(t['OpenPrice'], t['Volume']) + self.writeWenHuaSignal(wenhuaSignalFile, wenhuaSingalCount, datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), t['OpenPrice'], msg) + wenhuaSingalCount += 1 + msg = 'Cover@{},{} ({})'.format(t['ClosePrice'], t['Volume'], round(t['Profit'])) + self.writeWenHuaSignal(wenhuaSignalFile, wenhuaSingalCount, datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), t['ClosePrice'], msg) + wenhuaSingalCount += 1 + wenhuaSignalFile.close() +# wh_records = OrderedDict() +# for t in self.exportTradeList: +# if t['Direction'] is 'Long': +# k = '{}_{}_{}'.format(t['OpenTime'], 'Buy', t['OpenPrice']) +# # 生成文华用的指标信号 +# v = {'time': datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), 'price':t['OpenPrice'], 'action': 'Buy', 'volume':t['Volume']} +# r = wh_records.get(k,None) +# if r is not None: +# r['volume'] += t['Volume'] +# else: +# wh_records[k] = v +# +# k = '{}_{}_{}'.format(t['CloseTime'], 'Sell', t['ClosePrice']) +# # 生成文华用的指标信号 +# v = {'time': datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['ClosePrice'], 'action': 'Sell', 'volume': t['Volume']} +# r = wh_records.get(k, None) +# if r is not None: +# r['volume'] += t['Volume'] +# else: +# wh_records[k] = v +# +# else: +# k = '{}_{}_{}'.format(t['OpenTime'], 'Short', t['OpenPrice']) +# # 生成文华用的指标信号 +# v = {'time': datetime.strptime(t['OpenTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['OpenPrice'], 'action': 'Short', 'volume': t['Volume']} +# r = wh_records.get(k, None) +# if r is not None: +# r['volume'] += t['Volume'] +# else: +# wh_records[k] = v +# k = '{}_{}_{}'.format(t['CloseTime'], 'Cover', t['ClosePrice']) +# # 生成文华用的指标信号 +# v = {'time': datetime.strptime(t['CloseTime'], '%Y-%m-%d %H:%M:%S'), 'price': t['ClosePrice'], 'action': 'Cover', 'volume': t['Volume']} +# r = wh_records.get(k, None) +# if r is not None: +# r['volume'] += t['Volume'] +# else: +# wh_records[k] = v +# +# branchs = 0 +# count = 0 +# wh_signal_file = None +# for r in list(wh_records.values()): +# if count % 200 == 0: +# if wh_signal_file is not None: +# wh_signal_file.close() +# +# # 交易记录生成文华对应的公式 +# filename = os.path.abspath(os.path.join(self.get_logs_path(), +# '{}_WenHua_{}_{}.csv'.format(s, datetime.now().strftime('%Y%m%d_%H%M'), branchs))) +# branchs += 1 +# self.writeCtaLog(u'save trade records for WenHua:{}'.format(filename)) +# +# wh_signal_file = open(filename, mode='w') +# +# count += 1 +# if wh_signal_file is not None: +# self.writeWenHuaSignal(filehandle=wh_signal_file, count=count, bardatetime=r['time'],price=r['price'], text='{}({})'.format(r['action'],r['volume'])) +# if wh_signal_file is not None: +# wh_signal_file.close() # 导出每日净值记录表 if not self.dailyList: @@ -3066,19 +3773,20 @@ class BacktestingEngine(object): if self.daily_report_name == EMPTY_STRING: csvOutputFile2 = os.path.abspath(os.path.join(self.get_logs_path(), - 'DailyList_{0}.csv'.format(datetime.now().strftime('%Y%m%d_%H%M')))) + '{}_DailyList_{}.csv'.format(s, datetime.now().strftime('%Y%m%d_%H%M')))) else: csvOutputFile2 = self.daily_report_name self.writeCtaLog(u'save daily records to:{}'.format(csvOutputFile2)) csvWriteFile2 = open(csvOutputFile2, 'w', encoding='utf8',newline='') - fieldnames = ['date', 'capital','net', 'maxCapital','rate', 'commission', 'longMoney','shortMoney','occupyMoney','occupyRate','longPos','shortPos','benchmark'] + fieldnames = ['date','lastPrice','capital','net','maxCapital','rate','commission','longMoney','shortMoney','occupyMoney','occupyRate','longPos','shortPos','todayMarginLong','todayMarginShort','benchmark'] writer2 = csv.DictWriter(f=csvWriteFile2, fieldnames=fieldnames, dialect='excel') writer2.writeheader() for row in self.dailyList: writer2.writerow(row) + return def getResult(self): # 返回回测结果 @@ -3088,7 +3796,7 @@ class BacktestingEngine(object): d['maxCapital'] = self.maxNetCapital # 取消原 maxCapital if len(self.pnlList) == 0: - return {} + return {}, [], [] d['maxPnl'] = max(self.pnlList) d['minPnl'] = min(self.pnlList) @@ -3120,7 +3828,22 @@ class BacktestingEngine(object): d['averageLosing'] = averageLosing d['profitLossRatio'] = profitLossRatio - return d + # 计算Sharp + if not self.dailyList: + return + + capitalNetList = [] + capitalList = [] + for row in self.dailyList: + capitalNetList.append(row['net']) + capitalList.append(row['capital']) + + capital = pd.Series(capitalNetList) + log_returns = np.log(capital).diff().fillna(0) + sharpe = (log_returns.mean() * 252) / (log_returns.std() * np.sqrt(252)) + d['sharpe'] = sharpe + + return d, capitalNetList, capitalList #---------------------------------------------------------------------- def showBacktestingResult(self): @@ -3128,7 +3851,7 @@ class BacktestingEngine(object): if self.calculateMode != self.REALTIME_MODE: self.calculateBacktestingResult() - d = self.getResult() + d, dailyNetCapital, dailyCapital = self.getResult() if len(d) == 0: self.output(u'无交易结果') @@ -3166,11 +3889,12 @@ class BacktestingEngine(object): self.writeCtaNotification(u'平均每笔滑点成本:\t%s' %formatNumber(d['totalSlippage']/d['totalResult'])) self.writeCtaNotification(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult'])) - + self.writeCtaNotification(u'Sharpe Ratio:\t%s' % formatNumber(d['sharpe'])) + # 绘图 - """ import matplotlib import matplotlib.pyplot as plt + from matplotlib.ticker import MultipleLocator, FormatStrFormatter import numpy as np matplotlib.rcParams['figure.figsize'] = (20.0, 10.0) @@ -3180,34 +3904,90 @@ class BacktestingEngine(object): except ImportError: pass - pCapital = plt.subplot(4, 1, 1) - pCapital.set_ylabel("capital") - pCapital.plot(d['capitalList'], color='r', lw=0.8) + # 是否显示每日资金曲线 + isPlotDaily = False # DEBUG + #isPlotDaily = True + capitalStr = '' + if isPlotDaily == True: + daily_df = pd.DataFrame(self.dailyList) + daily_df = daily_df.set_index('date') - plt.title(u'{}~{},{} backtest result '.format(self.startDate, self.endDate, self.strategy_name)) + pCapital = plt.subplot(4, 1, 1) + pCapital.set_ylabel("trade capital") + pCapital.plot(d['capitalList'], color='r', lw=0.8) + plt.title(u'{}: {}~{}({}) NetCapital={}({}), #Trading={}({}/day), TotalCommission={}, MaxLots={}({}), MDD={}%'.format( + self.symbol, + self.startDate, self.endDate, len(dailyNetCapital), + dailyNetCapital[-1], min(d['drawdownList']), + d['totalResult'], int(d['totalResult']/len(dailyNetCapital)), + d['totalCommission'], + d['maxVolume'], max(daily_df['occupyMoney']), + self.daily_max_drawdown_rate)) + pCapital.grid() + capitalStr = '{}.{}'.format(round(dailyNetCapital[-1]), round(min(d['drawdownList']))) - pDD = plt.subplot(4, 1, 2) - pDD.set_ylabel("DD") - pDD.bar(range(len(d['drawdownList'])), d['drawdownList'], color='g') - - pPnl = plt.subplot(4, 1, 3) - pPnl.set_ylabel("pnl") - pPnl.hist(d['pnlList'], bins=50, color='c') + pDailyCapital = plt.subplot(4, 1, 3) + pDailyCapital.set_ylabel("daily capital") + pDailyCapital.plot(dailyCapital, color='b', lw=0.8, label='Capital') + pDailyCapital.plot(dailyNetCapital, color='r', lw=1, label='NetCapital') + # Change the label of X-Axes to date + xt = pDailyCapital.get_xticks() + interval = len(dailyNetCapital) / 10 + interval = round(interval) -1 + xt3 = list(range(-10, len(dailyNetCapital), interval)) + xt2 = [daily_df.index[int(i)] for i in xt3[1:]] + xt2.insert(0,'') + xt2.append('') + pDailyCapital.set_xticks(xt3) + pDailyCapital.set_xticklabels(xt2) + pDailyCapital.grid() + pDailyCapital.legend() + pLastPrice = plt.subplot(4, 1, 2) + pLastPrice.set_ylabel("daily lastprice") + pLastPrice.plot(daily_df['lastPrice'], color='y', lw=1, label='Price') + pLastPrice.set_xticks(xt3) + pLastPrice.set_xticklabels(xt2) + pLastPrice.grid() + pLastPrice.legend() + + pOccupyRate = plt.subplot(4, 1, 4) + pOccupyRate.set_ylabel("occupyMoney") + index = np.arange(len(daily_df['occupyMoney'])) + pOccupyRate.bar(index, daily_df['occupyMoney'], 0.4, color='b') + pOccupyRate.set_xticks(xt3) + pOccupyRate.set_xticklabels(xt2) + pOccupyRate.grid() + + else: + pCapital = plt.subplot(4, 1, 1) + pCapital.set_ylabel("capital") + pCapital.plot(d['capitalList'], color='r', lw=0.8) + + plt.title(u'{}~{},{} backtest result '.format(self.startDate, self.endDate, self.strategy_name)) + + pDD = plt.subplot(4, 1, 2) + pDD.set_ylabel("DD") + pDD.bar(range(len(d['drawdownList'])), d['drawdownList'], color='g') + + pPnl = plt.subplot(4, 1, 3) + pPnl.set_ylabel("pnl") + pPnl.hist(d['pnlList'], bins=50, color='c') plt.tight_layout() #plt.xticks(xindex, tradeTimeIndex, rotation=30) # 旋转15 fig_file_name = os.path.abspath(os.path.join(self.get_logs_path(), - '{}_plot_{}.png'.format(self.strategy_name, - datetime.now().strftime('%Y%m%d_%H%M')))) + '{}_plot_{}_{}.png'.format(self.strategy_name, + datetime.now().strftime('%Y%m%d_%H%M'), capitalStr))) fig = plt.gcf() fig.savefig(fig_file_name) self.output (u'图表保存至:{0}'.format(fig_file_name)) #plt.show() - """ + plt.close() + #---------------------------------------------------------------------- def putStrategyEvent(self, name): """发送策略更新事件,回测中忽略""" @@ -3245,7 +4025,8 @@ class BacktestingEngine(object): self.output(u'优化结果:') for result in resultList: self.output(u'%s: %s' %(result[0], result[1])) - return result + + return resultList #---------------------------------------------------------------------- def clearBacktestingResult(self): @@ -3307,6 +4088,20 @@ class BacktestingEngine(object): newPrice = round(price/priceTick, 0) * priceTick return newPrice + + def roundToVolumeTick(self,volumeTick,volume): + if volumeTick == 0: + return volume + newVolume = round(volume / volumeTick, 0) * volumeTick + if isinstance(volumeTick,float): + v_exponent = decimal.Decimal(str(newVolume)) + vt_exponent = decimal.Decimal(str(volumeTick)) + if abs(v_exponent.as_tuple().exponent) > abs(vt_exponent.as_tuple().exponent): + newVolume = round(newVolume, ndigits=abs(vt_exponent.as_tuple().exponent)) + newVolume = float(str(newVolume)) + + return newVolume + def getTradingDate(self, dt=None): """ 根据输入的时间,返回交易日的日期 @@ -3414,7 +4209,6 @@ class OptimizationSetting(object): """设置优化目标字段""" self.optimizeTarget = target - #---------------------------------------------------------------------- def formatNumber(n): """格式化数字到字符串""" diff --git a/vnpy/trader/app/ctaStrategy/ctaBase.py b/vnpy/trader/app/ctaStrategy/ctaBase.py index fce7de08..a6710445 100644 --- a/vnpy/trader/app/ctaStrategy/ctaBase.py +++ b/vnpy/trader/app/ctaStrategy/ctaBase.py @@ -56,7 +56,7 @@ ENGINETYPE_TRADING = 'trading' # 实盘 # CTA引擎中涉及的数据类定义 -from vnpy.trader.vtConstant import * +from vnpy.trader.vtConstant import EMPTY_STRING,EMPTY_UNICODE,EMPTY_FLOAT,EMPTY_INT,COLOR_EQUAL ######################################################################## class StopOrder(object): diff --git a/vnpy/trader/app/ctaStrategy/ctaEngine.py b/vnpy/trader/app/ctaStrategy/ctaEngine.py index 5f7ac80f..07d9d4ae 100644 --- a/vnpy/trader/app/ctaStrategy/ctaEngine.py +++ b/vnpy/trader/app/ctaStrategy/ctaEngine.py @@ -266,10 +266,10 @@ class CtaEngine(object): self.writeCtaLog(msg) # 发送微信 - try: - sendWeChatMsg(msg, target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try: + # sendWeChatMsg(msg, target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass return vtOrderID @@ -294,11 +294,11 @@ class CtaEngine(object): self.mainEngine.cancelOrder(req, order.gatewayName) # 发送微信 - try: - msg = u'发送撤单指令,%s, %s,%s' % (order.symbol, order.orderID, order.gatewayName) - sendWeChatMsg(msg, target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try: + # msg = u'发送撤单指令,%s, %s,%s' % (order.symbol, order.orderID, order.gatewayName) + # sendWeChatMsg(msg, target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass else: if order.status == STATUS_ALLTRADED: self.writeCtaLog(u'委托单({0}已执行,无法撤销'.format(vtOrderID)) @@ -343,11 +343,11 @@ class CtaEngine(object): self.mainEngine.cancelOrder(req, order.gatewayName) # 发送微信 - try: - msg = u'撤销所有单,{}'.format(symbol) - sendWeChatMsg(msg, target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try: + # msg = u'撤销所有单,{}'.format(symbol) + # sendWeChatMsg(msg, target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass # ---------------------------------------------------------------------- def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy): @@ -388,10 +388,10 @@ class CtaEngine(object): self.writeCtaLog(msg) # 发送微信 - try: - sendWeChatMsg(msg, target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try:# + # sendWeChatMsg(msg, target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass return stopOrderID # ---------------------------------------------------------------------- @@ -408,19 +408,19 @@ class CtaEngine(object): self.writeCtaLog(u'撤销停止单:{0}成功.'.format(stopOrderID)) # 发送微信 - try: - sendWeChatMsg(u'撤销停止单:{0}成功.'.format(stopOrderID), target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try: + # sendWeChatMsg(u'撤销停止单:{0}成功.'.format(stopOrderID), target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass return True else: self.writeCtaLog(u'撤销停止单:{0}失败,不存在Id.'.format(stopOrderID)) # 发送微信 - try: - sendWeChatMsg(u'撤销停止单:{0}失败,不存在Id.'.format(stopOrderID), target=self.ctaEngine.mainEngine.gatewayDetailList[0]['gatewayName']) - except: - pass + #try: + # sendWeChatMsg(u'撤销停止单:{0}失败,不存在Id.'.format(stopOrderID), target=self.mainEngine.gatewayDetailList[0]['gatewayName']) + #except: + # pass return False # ---------------------------------------------------------------------- diff --git a/vnpy/trader/app/ctaStrategy/language/__init__.py b/vnpy/trader/app/ctaStrategy/language/__init__.py index 4e1d5e96..8df93847 100644 --- a/vnpy/trader/app/ctaStrategy/language/__init__.py +++ b/vnpy/trader/app/ctaStrategy/language/__init__.py @@ -1,4 +1,4 @@ -# encoding: UTF-8 +# encoding: UTF-8 import json import os @@ -9,5 +9,5 @@ from vnpy.trader.app.ctaStrategy.language.chinese import text # 是否要使用英文 from vnpy.trader.vtGlobal import globalSetting -if globalSetting['language'] == 'english': +if len(globalSetting) > 0 and globalSetting['language'] == 'english': from vnpy.trader.app.ctaStrategy.language.english import text \ No newline at end of file diff --git a/vnpy/trader/app/ctaStrategy/strategy/__init__.py b/vnpy/trader/app/ctaStrategy/strategy/__init__.py index 53e31791..12fe5fe3 100644 --- a/vnpy/trader/app/ctaStrategy/strategy/__init__.py +++ b/vnpy/trader/app/ctaStrategy/strategy/__init__.py @@ -1,20 +1,42 @@ # encoding: UTF-8 ''' -动态载入所有的策略类 +动态载入所有的策略类,先从vnpy/trader/app/ctaStrategy/strategy下加载,其次,从工作目录下strategy加载。 +如果重复,工作目录的strategy优先。 ''' import os import importlib +import traceback # 用来保存策略类的字典 STRATEGY_CLASS = {} -# 获取目录路径 +# ---------------------------------------------------------------------- +def loadStrategyModule(moduleName): + """使用importlib动态载入模块""" + try: + print('loading {0}'.format(moduleName)) + module = importlib.import_module(moduleName) + + # 遍历模块下的对象,只有名称中包含'Strategy'的才是策略类 + for k in dir(module): + if 'Strategy' in k: + print('adding {} into STRATEGY_CLASS'.format(k)) + v = module.__getattribute__(k) + if k in STRATEGY_CLASS: + print('Replace strategy {} with {}'.format(k,moduleName)) + STRATEGY_CLASS[k] = v + except Exception as ex: + print('-' * 20) + print('Failed to import strategy file %s:' % moduleName) + print('Exception:{},{}'.format(str(ex),traceback.format_exc())) + + # 获取目录路径 path = os.path.abspath(os.path.dirname(__file__)) -print ('init {0}'.format(path)) +print('init strategies from {}'.format(path)) # 遍历strategy目录下的文件 for root, subdirs, files in os.walk(path): @@ -23,19 +45,20 @@ for root, subdirs, files in os.walk(path): if 'strategy' in name and '.pyc' not in name: # 模块名称需要上前缀 moduleName = 'vnpy.trader.app.ctaStrategy.strategy.' + name.replace('.py', '') - print ('loading {0}'.format(moduleName)) - try: - # 使用importlib动态载入模块 - module = importlib.import_module(moduleName) - except Exception as ex: - print ('load fail,excepion:{0}'.format(ex)) - continue + loadStrategyModule(moduleName) - # 遍历模块下的对象,只有名称中包含'Strategy'的才是策略类 - for k in dir(module): - if 'Strategy' in k: - print ('adding {0} into STRATEGY_CLASS'.format(k)) - v = module.__getattribute__(k) - STRATEGY_CLASS[k] = v -print( 'finished load strategy modules') \ No newline at end of file +# 遍历工作目录下的文件 +#stratey_working_path = os.path.abspath(os.path.join(os.getcwd(), 'strategy')) +# +#if os.path.exists(stratey_working_path): +# print('init strategies from {}'.format(stratey_working_path)) +# for root, subdirs, files in os.walk(stratey_working_path): +# for name in files: +# # 只有文件名中包含strategy且非.pyc的文件,才是策略文件 +# if 'strategy' in name and '.pyc' not in name: +# # 模块名称无需前缀 +# moduleName = name.replace('.py', '') +# loadStrategyModule(moduleName) +# +print('finished load strategy modules') diff --git a/vnpy/trader/app/spreadTrading/stAlgo.py b/vnpy/trader/app/spreadTrading/stAlgo.py index 2a5444d3..1b9fc2c5 100644 --- a/vnpy/trader/app/spreadTrading/stAlgo.py +++ b/vnpy/trader/app/spreadTrading/stAlgo.py @@ -496,7 +496,7 @@ class SniperAlgo(StAlgoTemplate): #---------------------------------------------------------------------- def cancelAllOrders(self): """撤销全部委托""" - for orderList in self.legOrderDict.values(): + for orderList in list(self.legOrderDict.values()): for vtOrderID in orderList: self.algoEngine.cancelOrder(vtOrderID) diff --git a/vnpy/trader/app/spreadTrading/stEngine.py b/vnpy/trader/app/spreadTrading/stEngine.py index 808bf565..f9c45dfa 100644 --- a/vnpy/trader/app/spreadTrading/stEngine.py +++ b/vnpy/trader/app/spreadTrading/stEngine.py @@ -274,7 +274,7 @@ class StDataEngine(object): #---------------------------------------------------------------------- def getAllSpreads(self): """获取所有的价差""" - return self.spreadDict.values() + return list(self.spreadDict.values()) ######################################################################## @@ -343,7 +343,7 @@ class StAlgoEngine(object): #---------------------------------------------------------------------- def processTimerEvent(self, event): """""" - for algo in self.algoDict.values(): + for algo in list(self.algoDict.values()): algo.updateTimer() #---------------------------------------------------------------------- @@ -450,7 +450,7 @@ class StAlgoEngine(object): def saveSetting(self): """保存算法配置""" setting = {} - for algo in self.algoDict.values(): + for algo in list(self.algoDict.values()): setting[algo.spreadName] = algo.getAlgoParams() f = shelve.open(self.algoFilePath) @@ -478,7 +478,7 @@ class StAlgoEngine(object): if not setting: return - for algo in self.algoDict.values(): + for algo in list(self.algoDict.values()): if algo.spreadName in setting: d = setting[algo.spreadName] algo.setAlgoParams(d) @@ -486,7 +486,7 @@ class StAlgoEngine(object): #---------------------------------------------------------------------- def stopAll(self): """停止全部算法""" - for algo in self.algoDict.values(): + for algo in list(self.algoDict.values()): algo.stop() #---------------------------------------------------------------------- @@ -506,7 +506,7 @@ class StAlgoEngine(object): #---------------------------------------------------------------------- def getAllAlgoParams(self): """获取所有算法的参数""" - return [algo.getAlgoParams() for algo in self.algoDict.values()] + return [algo.getAlgoParams() for algo in list(self.algoDict.values())] #---------------------------------------------------------------------- def setAlgoBuyPrice(self, spreadName, buyPrice): diff --git a/vnpy/trader/app/spreadTrading/uiStWidget.py b/vnpy/trader/app/spreadTrading/uiStWidget.py index cdecc3a9..1eb26a3d 100644 --- a/vnpy/trader/app/spreadTrading/uiStWidget.py +++ b/vnpy/trader/app/spreadTrading/uiStWidget.py @@ -429,7 +429,7 @@ class StAlgoManager(QtWidgets.QTableWidget): #---------------------------------------------------------------------- def stopAll(self): """停止所有算法""" - for button in self.buttonActiveDict.values(): + for button in list(self.buttonActiveDict.values()): button.stop() diff --git a/vnpy/trader/gateway/ctpGateway/ctpGateway.py b/vnpy/trader/gateway/ctpGateway/ctpGateway.py index c7cd21bf..15731757 100644 --- a/vnpy/trader/gateway/ctpGateway/ctpGateway.py +++ b/vnpy/trader/gateway/ctpGateway/ctpGateway.py @@ -383,10 +383,10 @@ class CtpMdApi(MdApi): def onRtnDepthMarketData(self, data): """行情推送""" # 忽略成交量为0的无效单合约tick数据 - if not data['Volume'] and '&' not in data['InstrumentID']: - self.writeLog(u'忽略成交量为0的无效单合约tick数据:') - self.writeLog(data) - return + #if not data['Volume'] and '&' not in data['InstrumentID']: + # self.writeLog(u'忽略成交量为0的无效单合约tick数据:') + # self.writeLog(data) + # return if not self.connectionStatus: self.connectionStatus = True @@ -407,11 +407,19 @@ class CtpMdApi(MdApi): #tick.time = '.'.join([data['UpdateTime'], str(data['UpdateMillisec']/100)]) # =》 Python 3 tick.time = '.'.join([data['UpdateTime'], str(data['UpdateMillisec'])]) - tick.date = data['TradingDay'] + # 上期所和郑商所可以直接使用,大商所需要转换 + tick.date = data['ActionDay'] + # 大商所日期转换 + if tick.exchange is EXCHANGE_DCE: + tick.date = datetime.now().strftime('%Y%m%d') + + #tick.date = data['TradingDay'] # add by Incense Lee tick.tradingDay = data['TradingDay'] - + if len(tick.tradingDay) == 8: + tradingDay = tick.tradingDay + tick.tradingDay = "{}-{}-{}".format(tradingDay[:4], tradingDay[4:6], tradingDay[6:]) # 先根据交易日期,生成时间 tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y%m%d %H:%M:%S.%f') # 修正时间 @@ -429,6 +437,8 @@ class CtpMdApi(MdApi): tick.datetime = tick.datetime + timedelta(days=2) tick.date = tick.datetime.strftime('%Y-%m-%d') + tick.date = tick.datetime.strftime('%Y-%m-%d') + tick.openPrice = data['OpenPrice'] tick.highPrice = data['HighestPrice'] tick.lowPrice = data['LowestPrice'] @@ -1437,7 +1447,7 @@ class CtpTdApi(TdApi): """""" pass - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def connect(self, userID, password, brokerID, address, authCode, userProductInfo): """初始化连接""" self.userID = userID # 账号 @@ -1472,7 +1482,7 @@ class CtpTdApi(TdApi): elif not self.loginStatus: self.login() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def login(self): """连接服务器""" # 如果填入了用户名密码等,则登录 @@ -1484,7 +1494,7 @@ class CtpTdApi(TdApi): self.reqID += 1 self.reqUserLogin(req, self.reqID) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def authenticate(self): """申请验证""" if self.userID and self.brokerID and self.authCode and self.userProductInfo: @@ -1496,13 +1506,13 @@ class CtpTdApi(TdApi): self.reqID +=1 self.reqAuthenticate(req, self.reqID) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def qryAccount(self): """查询账户""" self.reqID += 1 self.reqQryTradingAccount({}, self.reqID) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def qryPosition(self): """查询持仓""" self.reqID += 1 @@ -1511,7 +1521,7 @@ class CtpTdApi(TdApi): req['InvestorID'] = self.userID self.reqQryInvestorPosition(req, self.reqID) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def sendOrder(self, orderReq): """发单""" self.reqID += 1 @@ -1562,7 +1572,7 @@ class CtpTdApi(TdApi): vtOrderID = '.'.join([self.gatewayName, str(self.orderRef)]) return vtOrderID - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def cancelOrder(self, cancelOrderReq): """撤单""" self.reqID += 1 @@ -1581,12 +1591,12 @@ class CtpTdApi(TdApi): self.reqOrderAction(req, self.reqID) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def close(self): """关闭""" self.exit() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def writeLog(self, content): """发出日志""" log = VtLogData() diff --git a/vnpy/trader/language/chinese/constant.py b/vnpy/trader/language/chinese/constant.py index 5983fa70..1263a8d3 100644 --- a/vnpy/trader/language/chinese/constant.py +++ b/vnpy/trader/language/chinese/constant.py @@ -72,14 +72,21 @@ EXCHANGE_IDEALPRO = 'IDEALPRO' # IB外汇ECN EXCHANGE_CME = 'CME' # CME交易所 EXCHANGE_ICE = 'ICE' # ICE交易所 + EXCHANGE_OANDA = 'OANDA' # OANDA外汇做市商 + + EXCHANGE_OKCOIN = 'OKCOIN' # OKCOIN比特币交易所 EXCHANGE_HUOBI = 'HUOBI' # 火币比特币交易所 EXCHANGE_LHANG = 'LHANG' # 链行比特币交易所 EXCHANGE_OKEX = 'OKEX' # OKEX比特币交易所 EXCHANGE_BINANCE = 'BINANCE' # 币安比特币交易所 EXCHANGE_GATEIO = 'GATEIO' # gate.io比特币交易所 -EXCHANGE_FCOIN = 'FCOIN' # fcoin.com 比特币交易所 +EXCHANGE_BITFINEX = "BITFINEX" # Bitfinex比特币交易所 +EXCHANGE_BITMEX = 'BITMEX' # BitMEX比特币交易所 +EXCHANGE_FCOIN = 'FCOIN' # FCoin比特币交易所 +EXCHANGE_BIGONE = 'BIGONE' # BigOne比特币交易所 + # 货币类型 CURRENCY_USD = 'USD' # 美元 CURRENCY_CNY = 'CNY' # 人民币 diff --git a/vnpy/trader/language/english/constant.py b/vnpy/trader/language/english/constant.py index da3d6043..506057aa 100644 --- a/vnpy/trader/language/english/constant.py +++ b/vnpy/trader/language/english/constant.py @@ -80,7 +80,11 @@ EXCHANGE_LHANG = 'LHANG' # 链行比特币交易所 EXCHANGE_OKEX = 'OKEX' # OKEX比特币交易所 EXCHANGE_BINANCE = 'BINANCE' # 币安比特币交易所 EXCHANGE_GATEIO = 'GATEIO' # gate.io比特币交易所 -EXCHANGE_FCOIN = 'FCOIN' # fcoin.com 比特币交易所 +EXCHANGE_BITFINEX = "BITFINEX" # Bitfinex比特币交易所 +EXCHANGE_BITMEX = 'BITMEX' # BitMEX比特币交易所 +EXCHANGE_FCOIN = 'FCOIN' # FCoin比特币交易所 +EXCHANGE_BIGONE = 'BIGONE' # BigOne比特币交易所 + # 货币类型 CURRENCY_USD = 'USD' # 美元 CURRENCY_CNY = 'CNY' # 人民币 diff --git a/vnpy/trader/uiKLine/uiMulti4KLine.py b/vnpy/trader/uiKLine/uiMulti4KLine.py index 7eddb21c..dfba299d 100644 --- a/vnpy/trader/uiKLine/uiMulti4KLine.py +++ b/vnpy/trader/uiKLine/uiMulti4KLine.py @@ -26,7 +26,7 @@ class GridKline(QtWidgets.QWidget): self.parent = parent super(GridKline, self).__init__(parent) - self.periods = ['m30', 'h1', 'h2', 'd'] + self.periods = ['m30', 'h1', 'h2','d'] self.kline_dict = {} self.initUI() @@ -71,19 +71,23 @@ class GridKline(QtWidgets.QWidget): df = df.set_index(pd.DatetimeIndex(df['datetime'])) canvas.loadData(df, main_indicators=['ma5', 'ma10', 'ma18'], sub_indicators=['sk', 'sd']) - # 载入 回测引擎生成的成交记录 trade_list_file = 'TradeList.csv' if os.path.exists(trade_list_file): df_trade = pd.read_csv(trade_list_file) self.kline_dict['h1'].add_signals(df_trade) - # 载入策略生成的交易事务过程 tns_file = 'tns.csv' if os.path.exists(tns_file): df_tns = pd.read_csv(tns_file) self.kline_dict['h2'].add_trans_df(df_tns) self.kline_dict['d'].add_trans_df(df_tns) + markup_file = 'dist.csv' + if os.path.exists(markup_file): + df_markup = pd.read_csv(markup_file) + df_markup = df_markup[['datetime', 'price', 'operation']] + df_markup.rename(columns={'operation': 'markup'}, inplace=True) + self.kline_dict['m30'].add_markups(df_markup=df_markup, include_list=['balance'], exclude_list=['buy', 'short', 'sell', 'cover']) except Exception as ex: traceback.print_exc() diff --git a/vnpy/trader/util_mail.py b/vnpy/trader/util_mail.py index 9b55dd3e..3ef5913b 100644 --- a/vnpy/trader/util_mail.py +++ b/vnpy/trader/util_mail.py @@ -6,7 +6,7 @@ from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart import smtplib -from threading import * +from threading import Lock,Thread import time # 创建一个带附件的实例 @@ -15,6 +15,23 @@ global maillock maillock = Lock() +#用于发送邮件的邮箱列表 +senders_list = [('xxx010@163.com','xxxx'), + ('xxx013@163.com','xxxx'), + ('xxx014@163.com','xxxx'), + ('xxx015@163.com','xxxx'), + ('xxx016@163.com','xxxx'), + ('xxx017@163.com','xxxx'), + ('xxx018@163.com','xxxx'), + ('xxx019@163.com','xxxx'), + ('xxx020@163.com','xxxx'), + ('xxx021@163.com','xxxx'), + ('xxx022@163.com','xxxx'), + ('xxx023@163.com','xxxx'), + ('xxx024@163.com','xxxx'), + ('xxx025@163.com','xxxx'), + ('xxx026@163.com','xxxx'), + ('xxx027@163.com','xxxx')] class mail_thread(Thread): def __init__(self, to_list, subject, msgcontent, attachlist): super(mail_thread, self).__init__(name="mail_thread") @@ -36,6 +53,16 @@ class mail_thread(Thread): self.lock.acquire() print("lock acquire %s" % time.ctime()) + random_limit = len(senders_list) - 1 + if random_limit != 0: + index = random.randint(0, random_limit) + self.mailfrom, self.mailpwd = senders_list[index] + + if len(self.mailfrom)==0 or len(self.mailpwd) == 0: + print("sendmail user/pwd error!") + self.lock.release() + return + msg = MIMEMultipart() # 文本肉容 content = MIMEText(self.msgcontent, _subtype='plain', _charset='gb2312') diff --git a/vnpy/trader/vtEngine.py b/vnpy/trader/vtEngine.py index f1856d2a..d6554e3f 100644 --- a/vnpy/trader/vtEngine.py +++ b/vnpy/trader/vtEngine.py @@ -7,7 +7,7 @@ from collections import OrderedDict import os,sys import copy -from pymongo import MongoClient +from pymongo import MongoClient, ASCENDING from pymongo.errors import ConnectionFailure,AutoReconnect #import vnpy.trader.mongo_proxy @@ -26,7 +26,11 @@ import psutil try: from .util_mail import * except: - print('import util_mail fail') + print('import util_mail fail',file=sys.stderr) +try: + from .util_wechat import * +except: + print('import util_wechat fail',file=sys.stderr) LOG_DB_NAME = 'vt_logger' @@ -147,16 +151,16 @@ class MainEngine(object): def checkGatewayStatus(self,gatewayName): """check gateway connect status""" + # 借用检查网关状态来持久化合约数据 + self.save_contract_counter += 1 + if self.save_contract_counter > 60 and self.dataEngine is not None: + self.writeLog(u'保存持久化合约数据') + self.dataEngine.saveContracts() + self.save_contract_counter = 0 + if gatewayName in self.gatewayDict: gateway = self.gatewayDict[gatewayName] - # 借用检查网关状态来持久化合约数据 - self.save_contract_counter += 1 - if self.save_contract_counter > 60 and self.dataEngine is not None: - - self.dataEngine.saveContracts() - self.save_contract_counter = 0 - return gateway.checkStatus() else: self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) @@ -372,11 +376,22 @@ class MainEngine(object): # 写入本地log日志 if self.logger is not None: self.logger.error(content) + print('{}'.format(datetime.now()),file=sys.stderr) print(content, file=sys.stderr) else: print(content, file=sys.stderr) self.createLogger() + # 发出邮件/微信 + #try: + # if len(self.gatewayDetailList) > 0: + # target = self.gatewayDetailList[0]['gatewayName'] + # else: + # target = WECHAT_GROUP["DEBUG_01"] + # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_ERROR) + #except Exception as ex: + # print(u'send wechat exception:{}'.format(str(ex)),file=sys.stderr) + # ---------------------------------------------------------------------- def writeWarning(self, content): """快速发出告警日志事件""" @@ -399,6 +414,16 @@ class MainEngine(object): except: pass + # 发出微信 + #try: + # if len(self.gatewayDetailList) > 0: + # target = self.gatewayDetailList[0]['gatewayName'] + # else: + # target = WECHAT_GROUP["DEBUG_01"] + # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_WARNING) + #except Exception as ex: + # print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr) + # ---------------------------------------------------------------------- def writeNotification(self, content): """快速发出通知日志事件""" @@ -414,6 +439,16 @@ class MainEngine(object): except: pass + # 发出微信 + # try: + # if len(self.gatewayDetailList) > 0: + # target = self.gatewayDetailList[0]['gatewayName'] + # else: + # target = WECHAT_GROUP["DEBUG_01"] + # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_INFO) + # except Exception as ex: + # print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr) + # ---------------------------------------------------------------------- def writeCritical(self, content): """快速发出严重错误日志事件""" @@ -427,6 +462,7 @@ class MainEngine(object): # 写入本地log日志 if self.logger: self.logger.critical(content) + print('{}'.format(datetime.now()), file=sys.stderr) print(content, file=sys.stderr) else: print(content, file=sys.stderr) @@ -438,6 +474,16 @@ class MainEngine(object): except: pass + ## 发出微信 + #try: + # # if len(self.gatewayDetailList) > 0: + # target = self.gatewayDetailList[0]['gatewayName'] + # else: + # target = WECHAT_GROUP["DEBUG_01"] + # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_FATAL) + #except: + # pass +# # ---------------------------------------------------------------------- def dbConnect(self): """连接MongoDB数据库""" @@ -527,13 +573,18 @@ class MainEngine(object): self.writeError(u'dbInsertMany exception:{}'.format(str(ex))) # ---------------------------------------------------------------------- - def dbQuery(self, dbName, collectionName, d): + def dbQuery(self, dbName, collectionName, d, sortKey='', sortDirection=ASCENDING): """从MongoDB中读取数据,d是查询要求,返回的是数据库查询的指针""" try: if self.dbClient: db = self.dbClient[dbName] collection = db[collectionName] - cursor = collection.find(d) + + if sortKey: + cursor = collection.find(d).sort(sortKey, sortDirection) # 对查询出来的数据进行排序 + else: + cursor = collection.find(d) + if cursor: return list(cursor) else: diff --git a/vnpy/trader/vtGateway.py b/vnpy/trader/vtGateway.py index 18909260..747cbcc1 100644 --- a/vnpy/trader/vtGateway.py +++ b/vnpy/trader/vtGateway.py @@ -1,6 +1,6 @@ # encoding: UTF-8 -import time,os +import time,os,sys from datetime import datetime from vnpy.trader.vtEvent import * @@ -98,10 +98,11 @@ class VtGateway(object): event1.dict_['data'] = error self.eventEngine.put(event1) - logMsg = u'{0}:[{1}]:{2}'.format(error.gatewayName, error.errorID,error.errorMsg ) + logMsg = u'{} {}:[{}]:{}'.format(datetime.now(), error.gatewayName, error.errorID,error.errorMsg ) # 写入本地log日志 if self.logger: - self.logger.info(logMsg) + self.logger.error(logMsg) + print(logMsg,file=sys.stderr) else: self.createLogger() @@ -205,12 +206,9 @@ class VtGateway(object): error.errorMsg = content self.onError(error) + # 输出到错误管道 + print(u'{}:{} {}'.format(datetime.now(),self.gatewayName,content),file=sys.stderr) + if self.logger: self.logger.error(content) - - - - - -