diff --git a/vnpy/trader/gateway/ctpGateway/ctpGateway.py b/vnpy/trader/gateway/ctpGateway/ctpGateway.py index ca767ae0..6c26526f 100644 --- a/vnpy/trader/gateway/ctpGateway/ctpGateway.py +++ b/vnpy/trader/gateway/ctpGateway/ctpGateway.py @@ -1,24 +1,26 @@ # encoding: UTF-8 -''' -vn.ctp的gateway接入 +""" +ctp的gateway接入 -考虑到现阶段大部分CTP中的ExchangeID字段返回的都是空值 -vtSymbol直接使用symbol -''' -print('loading ctpGateway.py') -import os +1、增加通达信行情(指数行情)的订阅 +2、增加支持Bar行情的订阅 +3、增加支持自定义套利合约的订阅 +""" + +import os,sys import json +import redis # 加载经booster编译转换的SO API库 from vnpy.trader.gateway.ctpGateway.vnctpmd import MdApi from vnpy.trader.gateway.ctpGateway.vnctptd import TdApi -print(u'loaded vnctpmd/vnctptd') + from vnpy.trader.vtConstant import * from vnpy.trader.vtGateway import * from vnpy.trader.gateway.ctpGateway.language import text from vnpy.trader.gateway.ctpGateway.ctpDataType import * -from vnpy.trader.vtFunction import getJsonPath,getShortSymbol +from vnpy.trader.vtFunction import getJsonPath,getShortSymbol,roundToPriceTick from vnpy.trader.app.ctaStrategy.ctaBase import MARKET_DAY_ONLY,NIGHT_MARKET_SQ1,NIGHT_MARKET_SQ2,NIGHT_MARKET_SQ3,NIGHT_MARKET_ZZ,NIGHT_MARKET_DL from datetime import datetime,timedelta @@ -95,7 +97,6 @@ symbolExchangeDict = {} # 夜盘交易时间段分隔判断 NIGHT_TRADING = datetime(1900, 1, 1, 20).time() - ######################################################################## class CtpGateway(VtGateway): """CTP接口""" @@ -108,7 +109,7 @@ class CtpGateway(VtGateway): self.mdApi = None # 行情API self.tdApi = None # 交易API self.tdxApi = None # 通达信指数行情API - + self.redisApi = None # redis行情API self.mdConnected = False # 行情API连接状态,登录完成后为True self.tdConnected = False # 交易API连接状态 @@ -120,9 +121,16 @@ class CtpGateway(VtGateway): self.subscribedSymbols = set() # 已订阅合约代码 self.requireAuthentication = False + self.debug_tick = False + self.tdx_pool_count = 2 # 通达信连接池内连接数 - #---------------------------------------------------------------------- + self.combiner_conf_dict = {} # 保存合成器配置 + # 自定义价差/加比的tick合成器 + self.combiners = {} + self.tick_combiner_map = {} + + # ---------------------------------------------------------------------- def connect(self): """连接""" # 载入json文件 @@ -157,6 +165,8 @@ class CtpGateway(VtGateway): tdAddress = str(setting['tdAddress']) mdAddress = str(setting['mdAddress']) + self.debug_tick = setting.get('debug_tick',False) + # 如果json文件提供了验证码 if 'authCode' in setting: authCode = str(setting['authCode']) @@ -166,8 +176,22 @@ class CtpGateway(VtGateway): authCode = None userProductInfo = None - # 如果没有初始化tdxApi - if self.tdxApi is None: + # 获取redis行情配置 + redis_conf = setting.get('redis', None) + if redis_conf is not None and isinstance(redis_conf, dict): + if self.redisApi is None: + self.writeLog(u'RedisApi接口未实例化,创建实例') + self.redisApi = RedisMdApi(self) # Redis行情API + else: + self.writeLog(u'Redis行情接口已实例化') + + ip_list = redis_conf.get('ip_list', None) + if ip_list is not None and len(ip_list) > 0: + self.writeLog(u'使用配置文件的redis服务器清单:{}'.format(ip_list)) + self.redisApi.ip_list = copy.copy(ip_list) + + # 如果没有初始化restApi,就初始化tdxApi + if self.redisApi is None and self.tdxApi is None: self.writeLog(u'通达信接口未实例化,创建实例') self.tdxApi = TdxMdApi(self) # 通达信行情API @@ -185,6 +209,15 @@ class CtpGateway(VtGateway): # 获取通达信得缺省连接池数量 self.tdx_pool_count = tdx_conf.get('pool_count', self.tdx_pool_count) + # 获取自定义价差/价比合约的配置 + try: + from vnpy.trader.vtEngine import Custom_Contract + c = Custom_Contract() + self.combiner_conf_dict = c.get_config() + if len(self.combiner_conf_dict)>0: + self.writeLog(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict)) + except Exception as ex: + pass except KeyError: self.writeLog(text.CONFIG_KEY_MISSING) return @@ -207,38 +240,98 @@ class CtpGateway(VtGateway): self.writeLog(u'有指数订阅,连接通达信行情服务器') self.tdxApi.connect(self.tdx_pool_count) self.tdxApi.subscribe(req) + elif self.redisApi is not None: + self.writeLog(u'有指数订阅,连接Redis行情服务器') + self.redisApi.connect() + self.redisApi.subscribe(req) else: self.mdApi.subscribe(req) - + + def add_spread_conf(self, conf): + """添加价差行情配置""" + self.writeLog(u'添加价差行情配置:{}'.format(conf)) + #---------------------------------------------------------------------- def subscribe(self, subscribeReq): """订阅行情""" - if self.mdApi is not None: - # 指数合约,从tdx行情订阅 - if subscribeReq.symbol[-2:] in ['99']: - subscribeReq.symbol = subscribeReq.symbol.upper() - if self.tdxApi: - self.tdxApi.subscribe(subscribeReq) + try: + if self.mdApi is not None: - else: - self.mdApi.subscribe(subscribeReq) + # 如果是自定义的套利合约符号 + if subscribeReq.symbol in self.combiner_conf_dict: + self.writeLog(u'订阅自定义套利合约:{}'.format(subscribeReq.symbol)) + # 创建合成器 + if subscribeReq.symbol not in self.combiners: + setting = self.combiner_conf_dict.get(subscribeReq.symbol) + setting.update({"vtSymbol":subscribeReq.symbol}) + combiner = TickCombiner(self, setting) + # 更新合成器 + self.writeLog(u'添加{}与合成器映射'.format(subscribeReq.symbol)) + self.combiners.update({setting.get('vtSymbol'): combiner}) - # Allow the strategies to start before the connection - self.subscribedSymbols.add(subscribeReq) + # 增加映射( leg1 对应的合成器列表映射) + leg1_symbol = setting.get('leg1_symbol') + combiner_list = self.tick_combiner_map.get(leg1_symbol, []) + if combiner not in combiner_list: + self.writeLog(u'添加Leg1:{}与合成器得映射'.format(leg1_symbol)) + combiner_list.append(combiner) + self.tick_combiner_map.update({leg1_symbol: combiner_list}) + + # 增加映射( leg2 对应的合成器列表映射) + leg2_symbol = setting.get('leg2_symbol') + combiner_list = self.tick_combiner_map.get(leg2_symbol, []) + if combiner not in combiner_list: + self.writeLog(u'添加Leg2:{}与合成器得映射'.format(leg2_symbol)) + combiner_list.append(combiner) + self.tick_combiner_map.update({leg2_symbol: combiner_list}) + + self.writeLog(u'订阅leg1:{}'.format(leg1_symbol)) + leg1_req = VtSubscribeReq() + leg1_req.symbol = leg1_symbol + leg1_req.exchange = subscribeReq.exchange + self.subscribe(leg1_req) + + self.writeLog(u'订阅leg2:{}'.format(leg2_symbol)) + leg2_req = VtSubscribeReq() + leg2_req.symbol = leg2_symbol + leg2_req.exchange = subscribeReq.exchange + self.subscribe(leg2_req) + + self.subscribedSymbols.add(subscribeReq) + else: + self.writeLog(u'{}合成器已经在存在'.format(subscribeReq.symbol)) + return + elif subscribeReq.symbol.endswith('SPD'): + self.writeError(u'自定义合约{}不在CTP设置中'.format(subscribeReq.symbol)) + + # 指数合约,从tdx行情订阅 + if subscribeReq.symbol[-2:] in ['99']: + subscribeReq.symbol = subscribeReq.symbol.upper() + if self.tdxApi: + self.tdxApi.subscribe(subscribeReq) + elif self.redisApi: + self.redisApi.subscribe(subscribeReq) + else: + self.mdApi.subscribe(subscribeReq) + + # Allow the strategies to start before the connection + self.subscribedSymbols.add(subscribeReq) + except Exception as ex: + self.writeError(u'订阅合约异常:{},{}'.format(str(ex),traceback.format_exc())) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def sendOrder(self, orderReq): """发单""" if self.tdApi is not None: return self.tdApi.sendOrder(orderReq) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def cancelOrder(self, cancelOrderReq): """撤单""" if self.tdApi is not None: self.tdApi.cancelOrder(cancelOrderReq) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def qryAccount(self): """查询账户资金""" if self.tdApi is None: @@ -246,7 +339,7 @@ class CtpGateway(VtGateway): return self.tdApi.qryAccount() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def qryPosition(self): """查询持仓""" if self.tdApi is None: @@ -342,6 +435,13 @@ class CtpGateway(VtGateway): log.logContent = content self.onLog(log) + def onCustomerTick(self,tick): + """推送自定义合约行情""" + # 自定义合约行情 + + for combiner in self.tick_combiner_map.get(tick.vtSymbol, []): + tick = copy.copy(tick) + combiner.onTick(tick) ######################################################################## class CtpMdApi(MdApi): @@ -366,7 +466,7 @@ class CtpMdApi(MdApi): self.password = EMPTY_STRING # 密码 self.brokerID = EMPTY_STRING # 经纪商代码 self.address = EMPTY_STRING # 服务器地址 - + #---------------------------------------------------------------------- def onFrontConnected(self): """服务器连接""" @@ -486,7 +586,7 @@ class CtpMdApi(MdApi): # 不处理开盘前的tick数据 if dt.hour in [8,20] and dt.minute < 59: return - if tick.exchange is EXCHANGE_CFFEX and dt.hour ==9 and dt.minute < 14: + if tick.exchange is EXCHANGE_CFFEX and dt.hour == 9 and dt.minute < 14: return # 日期,取系统时间的日期 @@ -528,7 +628,45 @@ class CtpMdApi(MdApi): tick.askPrice1 = data['AskPrice1'] tick.askVolume1 = data['AskVolume1'] + if data.get('BidPrice2',None) !=None: + tick.bidPrice2 = data.get('BidPrice2') + if data.get('BidPrice3', None) != None: + tick.bidPrice3 = data.get('BidPrice3') + if data.get('BidPrice4', None) != None: + tick.bidPrice4 = data.get('BidPrice4') + if data.get('BidPrice5', None) != None: + tick.bidPrice5 = data.get('BidPrice5') + + if data.get('AskPrice2',None) !=None: + tick.AskPrice2 = data.get('AskPrice2') + if data.get('AskPrice3', None) != None: + tick.AskPrice3 = data.get('AskPrice3') + if data.get('AskPrice4', None) != None: + tick.AskPrice4 = data.get('AskPrice4') + if data.get('AskPrice5', None) != None: + tick.AskPrice5 = data.get('AskPrice5') + + if data.get('BidVolume2',None) !=None: + tick.bidVolume2 = data.get('BidVolume2') + if data.get('BidVolume3', None) != None: + tick.bidVolume3 = data.get('BidVolume3') + if data.get('BidVolume4', None) != None: + tick.bidVolume4 = data.get('BidVolume4') + if data.get('BidVolume5', None) != None: + tick.bidVolume5 = data.get('BidVolume5') + + if data.get('AskVolume2',None) !=None: + tick.AskVolume2 = data.get('AskVolume2') + if data.get('AskVolume3', None) != None: + tick.AskVolume3 = data.get('AskVolume3') + if data.get('AskVolume4', None) != None: + tick.AskVolume4 = data.get('AskVolume4') + if data.get('AskVolume5', None) != None: + tick.AskVolume5 = data.get('AskVolume5') + self.gateway.onTick(tick) + + self.gateway.onCustomerTick(tick) #---------------------------------------------------------------------- def onRspSubForQuoteRsp(self, data, error, n, last): @@ -577,15 +715,11 @@ class CtpMdApi(MdApi): """订阅合约""" # 这里的设计是,如果尚未登录就调用了订阅方法 # 则先保存订阅请求,登录完成后会自动订阅 - #if self.loginStatus: - print(u'subscribe {0}'.format(str(subscribeReq.symbol))) + + # 订阅传统合约 self.subscribeMarketData(str(subscribeReq.symbol)) self.writeLog(u'订阅合约:{0}'.format(str(subscribeReq.symbol))) - #else: - # print u'not login, add {0} into subscribe list'.format(str(subscribeReq.symbol)) - # self.writeLog(u'未连接,增加合约{0}至待订阅列表'.format(str(subscribeReq.symbol))) - - self.subscribedSymbols.add(subscribeReq) + self.subscribedSymbols.add(subscribeReq) #---------------------------------------------------------------------- def login(self): @@ -879,51 +1013,77 @@ class CtpTdApi(TdApi): if not self.gateway.tdConnected: self.gateway.tdConnected = True - # 获取持仓缓存对象 - posName = '.'.join([data['InstrumentID'], data['PosiDirection']]) - if posName in self.posDict: - pos = self.posDict[posName] - else: - pos = VtPositionData() - self.posDict[posName] = pos - - pos.gatewayName = self.gatewayName - pos.symbol = data['InstrumentID'] - pos.vtSymbol = pos.symbol - pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') - pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName]) + try: + # 获取持仓缓存对象 + posName = '.'.join([data['InstrumentID'], data['PosiDirection']]) + if posName in self.posDict: + pos = self.posDict[posName] + else: + pos = VtPositionData() + self.posDict[posName] = pos - # 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据 - if data['YdPosition'] and not data['TodayPosition']: - pos.ydPosition = data['Position'] + pos.gatewayName = self.gatewayName + pos.symbol = data['InstrumentID'] + pos.vtSymbol = pos.symbol + pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') + pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName]) - # 计算成本 - cost = pos.price * pos.position + exchange = self.symbolExchangeDict.get(pos.symbol, EXCHANGE_UNKNOWN) - # 汇总总仓 - pos.position += data['Position'] - pos.positionProfit += data['PositionProfit'] + # 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据 + if exchange == EXCHANGE_SHFE: + if data['YdPosition'] and not data['TodayPosition']: + pos.ydPosition = data['Position'] + # 否则基于总持仓和今持仓来计算昨仓数据 + else: + pos.ydPosition = data['Position'] - data['TodayPosition'] - # 计算持仓均价 - if pos.position and pos.symbol in self.symbolSizeDict: + # 计算成本 + if pos.symbol not in self.symbolSizeDict: + return size = self.symbolSizeDict[pos.symbol] - if size > 0 and pos.position > 0: - pos.price = (cost + data['PositionCost']) / abs(pos.position * size) + cost = pos.price * pos.position * size - # 读取冻结 - if pos.direction is DIRECTION_LONG: - pos.frozen += data['LongFrozen'] - else: - pos.frozen += data['ShortFrozen'] + # 汇总总仓 + pos.position += data['Position'] + + # 计算持仓均价 + if pos.position and size: + #pos.price = (cost + data['PositionCost']) / (pos.position * size) + pos.price = (cost + data['OpenCost']) / (pos.position * size) + + # 上一交易日结算价 + pre_settlement_price = data['PreSettlementPrice'] + # 开仓均价 + open_cost_price = data['OpenCost'] / (pos.position * size) + # 逐笔盈亏 = (上一交易日结算价 - 开仓价)* 持仓数量 * 杠杆 + 当日持仓收益 + if pos.direction == DIRECTION_LONG: + pre_profit = (pre_settlement_price - open_cost_price) * (pos.position * size) + else: + pre_profit = (open_cost_price - pre_settlement_price) * (pos.position * size) + + pos.positionProfit = pos.positionProfit + pre_profit + data['PositionProfit'] + + # 读取冻结 + if pos.direction is DIRECTION_LONG: + pos.frozen += data['LongFrozen'] + else: + pos.frozen += data['ShortFrozen'] + + # 查询回报结束 + if last: + # 遍历推送 + if self.gateway.debug: + print(u'最后推送') + for pos in list(self.posDict.values()): + self.gateway.onPosition(pos) + + # 清空缓存 + self.posDict.clear() + except Exception as ex: + self.gateway.writeError('onRspQryInvestorPosition exception:{}'.format(str(ex))) + self.gateway.writeError('traceL{}'.format(traceback.format_exc())) - # 查询回报结束 - if last: - # 遍历推送 - for pos in list(self.posDict.values()): - self.gateway.onPosition(pos) - - # 清空缓存 - self.posDict.clear() #---------------------------------------------------------------------- def onRspQryTradingAccount(self, data, error, n, last): @@ -1024,9 +1184,23 @@ class CtpTdApi(TdApi): self.symbolExchangeDict[contract.symbol] = contract.exchange self.symbolSizeDict[contract.symbol] = contract.size + idx_contract = copy.copy(contract) + # 推送 self.gateway.onContract(contract) - + + # 生成指数合约信息 + short_symbol= getShortSymbol(idx_contract.symbol).upper() # 短合约名称 + # 只推送普通合约的指数 + if short_symbol!= idx_contract.symbol.upper() and len(short_symbol)<=2 and contract.optionType==EMPTY_UNICODE: + idx_contract.symbol = '{}99'.format(short_symbol) + idx_contract.vtSymbol = idx_contract.symbol + idx_contract.name = u'{}指数'.format(short_symbol.upper()) + #self.writeLog(u'更新指数{}的合约信息,size:{}, longMarginRatio:{},shortMarginRatio{}' + # .format(idx_contract.vtSymbol,idx_contract.size, idx_contract.longMarginRatio,contract.shortMarginRatio)) + + self.gateway.onContract(idx_contract) + if last: self.writeLog(text.CONTRACT_DATA_RECEIVED) @@ -1704,12 +1878,26 @@ class TdxMdApi(): self.pool = None # 线程池 #self.req_thread = None # 定时器线程 - self.ip_list = [{'ip': '112.74.214.43', 'port': 7727}, - {'ip': '59.175.238.38', 'port': 7727}, - {'ip': '124.74.236.94', 'port': 7721}, - {'ip': '124.74.236.94', 'port': 7721}, - {'ip': '58.246.109.27', 'port': 7721} - ] + self.ip_list = [ + {"ip": "106.14.95.149", "port": 7727, "name": "扩展市场上海双线"}, + #{"ip": "112.74.214.43", "port": 7727, "name": "扩展市场深圳双线1"}, + # {"ip": "113.105.142.136", "port": 443, "name": "扩展市场东莞主站"}, + {"ip": "119.147.86.171", "port": 7727, "name": "扩展市场深圳主站"}, + {"ip": "119.97.185.5", "port": 7727, "name": "扩展市场武汉主站1"}, + {"ip": "120.24.0.77", "port": 7727, "name": "扩展市场深圳双线2"}, + {"ip": "124.74.236.94", "port": 7721}, + {"ip": "202.103.36.71", "port": 443, "name": "扩展市场武汉主站2"}, + {"ip": "47.92.127.181", "port": 7727, "name": "扩展市场北京主站"}, + {"ip": "59.175.238.38", "port": 7727, "name": "扩展市场武汉主站3"}, + {"ip": "61.152.107.141", "port": 7727, "name": "扩展市场上海主站1"}, + {"ip": "61.152.107.171", "port": 7727, "name": "扩展市场上海主站2"}, + # {"ip": "124.74.236.94","port": 7721}, + # {"ip": "218.80.248.229", "port": 7721}, + # {"ip": "58.246.109.27", "port": 7721}, + # added 20190222 from tdx + {"ip": "119.147.86.171", "port": 7721, "name": "扩展市场深圳主站"}, + {"ip": "47.107.75.159", "port": 7727, "name": "扩展市场深圳双线3"}, + ] # 调出 {'ip': '218.80.248.229', 'port': 7721}, self.best_ip = {'ip': None, 'port': None} @@ -1717,6 +1905,9 @@ class TdxMdApi(): self.last_tick_dt = {} # 记录该会话对象的最后一个tick时间 self.instrument_count = 50000 + + self.has_qry_instrument = False + # ---------------------------------------------------------------------- def ping(self, ip, port=7709): """ @@ -1892,6 +2083,9 @@ class TdxMdApi(): if not self.connection_status: return + if self.has_qry_instrument: + return + api = self.api_dict.get(0) if api is None: self.writeLog(u'取不到api连接,更新合约信息失败') @@ -1921,7 +2115,7 @@ class TdxMdApi(): elif tdx_market_id == 30: # 上期所+能源 self.symbol_exchange_dict[tdx_symbol] = EXCHANGE_SHFE - # 如果有预定的订阅合约,提前订阅 + self.has_qry_instrument = True def run(self, i): """ @@ -1992,7 +2186,7 @@ class TdxMdApi(): #self.writeLog(u'tdx[{}] get_instrument_quote:({},{})'.format(i,self.symbol_market_dict.get(symbol),symbol)) rt_list = api.get_instrument_quote(self.symbol_market_dict.get(symbol),symbol) - if len(rt_list) == 0: + if rt_list is None or len(rt_list) == 0: self.writeLog(u'tdx[{}]: rt_list为空'.format(i)) return #else: @@ -2077,10 +2271,16 @@ class TdxMdApi(): if tick.exchange is EXCHANGE_CFFEX: if tick.datetime.hour not in [9,10,11,13,14,15]: return - if tick.datetime.hour == 9 and tick.datetime.minute < 15: - return - if tick.datetime.hour == 15 and tick.datetime.minute >= 15: - return + if short_symbol in ['IC','IF','IH']: + if tick.datetime.hour == 9 and tick.datetime.minute < 30: + return + if tick.datetime.hour == 15 and tick.datetime.minute >= 0: + return + else: # TF, T + if tick.datetime.hour == 9 and tick.datetime.minute < 15: + return + if tick.datetime.hour == 15 and tick.datetime.minute >= 15: + return else: # 大商所/郑商所,上期所,上海能源 # 排除非开盘小时 if tick.datetime.hour in [3,4,5,6,7,8,12,15,16,17,18,19,20]: @@ -2099,14 +2299,14 @@ class TdxMdApi(): return # 排除大商所/郑商所夜盘数据 - if short_symbol in NIGHT_MARKET_DL or short_symbol in NIGHT_MARKET_ZZ: + if short_symbol in NIGHT_MARKET_ZZ: if tick.datetime.hour == 23 and tick.datetime.minute>=30: return if tick.datetime.hour in [0,1,2]: return # 排除上期所夜盘数据 23:00 收盘 - if short_symbol in NIGHT_MARKET_SQ3: + if short_symbol in NIGHT_MARKET_SQ3 or short_symbol in NIGHT_MARKET_DL: if tick.datetime.hour in [23,0,1,2]: return # 排除上期所夜盘数据 1:00 收盘 @@ -2118,8 +2318,9 @@ class TdxMdApi(): if short_symbol in MARKET_DAY_ONLY and (tick.datetime.hour < 9 or tick.datetime.hour > 16): #self.writeLog(u'排除日盘合约{}在夜盘得数据'.format(short_symbol)) return - """ - self.writeLog('{},{},{},{},{},{},{},{},{},{},{},{},{},{}'.format(tick.gatewayName, tick.symbol, + + if self.gateway.debug_tick: + self.writeLog('tdx {},{},{},{},{},{},{},{},{},{},{},{},{},{}'.format(tick.gatewayName, tick.symbol, tick.exchange, tick.vtSymbol, tick.datetime, tick.tradingDay, tick.openPrice, tick.highPrice, @@ -2127,11 +2328,12 @@ class TdxMdApi(): tick.bidPrice1, tick.bidVolume1, tick.askPrice1, tick.askVolume1)) - """ + self.symbol_tick_dict[symbol] = tick self.gateway.onTick(tick) + self.gateway.onCustomerTick(tick) # ---------------------------------------------------------------------- def writeLog(self, content): @@ -2144,6 +2346,644 @@ class TdxMdApi(): def writeError(self,content): self.gateway.writeError(content) +class RedisMdApi(): + """ + Redis数据行情API实现 + 通过线程,查询订阅的行情,更新合约的数据 + + """ + + def __init__(self, gateway): + """Constructor""" + self.EventType = "RedisQuotation" + self.gateway = gateway # gateway对象 + self.gatewayName = gateway.gatewayName # gateway对象名称 + + self.req_interval = 0.5 # 操作请求间隔500毫秒 + self.connection_status = False # 连接状态 + + self.symbol_tick_dict = {} # Redis合约与最后一个Tick得字典 + self.registed_symbol_set = set() + + # Redis服务器列表 + self.ip_list = [{'ip': '192.168.1.211', 'port': 6379}, + {'ip': '192.168.1.212', 'port': 6379}, + {'ip': '192.168.0.203', 'port': 6379} + ] + self.best_ip = {'ip': None, 'port': None} # 最佳服务器 + + self.last_tick_dt = None # 记录该会话对象的最后一个tick时间 + + # 查询线程 + self.quotation_thread = None + + # ---------------------------------------------------------------------- + def ping(self, ip, port=6379): + """ + ping行情服务器 + :param ip: + :param port: + :param type_: + :return: + """ + __time1 = datetime.now() + try: + r = redis.Redis(host=ip, port=port, db=0, socket_connect_timeout=2) + r.set('ping', ip) + _timestamp = datetime.now() - __time1 + self.writeLog('Redis服务器{}:{},耗时:{}'.format(ip, port, _timestamp)) + return _timestamp + except: + self.writeError(u'Redis ping服务器,异常的响应{}'.format(ip)) + return timedelta(9, 9, 0) + + # ---------------------------------------------------------------------- + def select_best_ip(self): + """ + 选择行情服务器 + :return: + """ + self.writeLog(u'选择Redis行情服务器') + + data_future = [self.ping(x['ip'], x['port']) for x in self.ip_list] + + best_future_ip = self.ip_list[data_future.index(min(data_future))] + + self.writeLog(u'选取 {}:{}'.format( + best_future_ip['ip'], best_future_ip['port'])) + return best_future_ip + + def connect(self): + """ + 连接Redis行情服务器 + :return: + """ + if self.connection_status: + return + + self.writeLog(u'开始连接Redis行情服务器') + + try: + # 选取最佳服务器 + if self.best_ip['ip'] is None and self.best_ip['port'] is None: + self.best_ip = self.select_best_ip() + + # 创建Redis连接对象实例 + pool = redis.ConnectionPool(host=self.best_ip['ip'], port=self.best_ip['port']) + self.redis = redis.Redis(connection_pool=pool) + + self.redis.set('Connect', self.best_ip['ip']) + self.writeLog(u'连接Redis服务器{}:{}成功'.format(self.best_ip['ip'], self.best_ip['port'])) + + self.last_tick_dt = datetime.now() + self.connection_status = True + + # 查询线程 + self.quotation_thread = Thread(target=self.run, name="QuotationEngine.%s" % self.EventType) + self.quotation_thread.setDaemon(False) + # 调用run方法 + self.quotation_thread.start() + except Exception as ex: + self.writeLog(u'连接Redis服务器{}:{}异常:{},{}'.format(self.best_ip['ip'], self.best_ip['port'], + str(ex), traceback.format_exc())) + return + + def reconnect(self, i): + """ + 重连 + :param i: + :return: + """ + self.writeLog(u'开始重连redis服务器') + try: + # 选取最佳服务器 + self.best_ip = self.select_best_ip() + + # 创建Redis连接对象实例 + pool = redis.ConnectionPool(host=self.best_ip['ip'], port=self.best_ip['port']) + self.redis = redis.Redis(connection_pool=pool) + self.redis.set('Connect', self.best_ip['ip']) + self.writeLog(u'重新连接Redis服务器{}:{}成功'.format(self.best_ip['ip'], self.best_ip['port'])) + + self.last_tick_dt = datetime.now() + self.connection_status = True + + # 查询线程 + self.quotation_thread = Thread(target=self.run, name="QuotationEngine.%s" % self.EventType) + self.quotation_thread.setDaemon(False) + # 调用run方法 + self.quotation_thread.start() + + except Exception as ex: + self.writeLog(u'重新连接Redis服务器{}:{}异常:{},{}'.format(self.best_ip['ip'], self.best_ip['port'], + str(ex), traceback.format_exc())) + return + + def close(self): + """退出API""" + self.connection_status = False + + # ---------------------------------------------------------------------- + def subscribe(self, subscribeReq): + """订阅合约""" + + # 这里的设计是,如果尚未登录就调用了订阅方法 + # 则先保存订阅请求,登录完成后会自动订阅 + vn_symbol = str(subscribeReq.symbol) + vn_symbol = vn_symbol.upper() + + if vn_symbol[-2:] != '99': + self.writeLog(u'Redis行情订阅: {}不是指数合约,不能订阅'.format(vn_symbol)) + return + + redis_symbol_0 = vn_symbol.upper() + '.2' + redis_symbol_1 = vn_symbol.upper() + '.3' + self.writeLog(u'Redis行情订阅: {}=>{} & {}'.format(vn_symbol, redis_symbol_0, redis_symbol_1)) + + if redis_symbol_0 not in self.registed_symbol_set: + self.registed_symbol_set.add(redis_symbol_0) + if redis_symbol_1 not in self.registed_symbol_set: + self.registed_symbol_set.add(redis_symbol_1) + + self.checkStatus() + + def checkStatus(self): + """检查连接状态""" + + if len(self.registed_symbol_set) == 0: + return + + # 若还没有启动连接,就启动连接 + if self.last_tick_dt is None: + over_time = False + else: + over_time = ((datetime.now() - self.last_tick_dt).total_seconds() > 60) + + if not self.connection_status or over_time is True: + self.writeLog(u'Redis服务器{}:{} 还没有连接,启动连接'.format(self.best_ip['ip'], self.best_ip['port'])) + self.close() + self.connect() + + def run(self): + """ + :return: + """ + try: + last_dt = datetime.now() + self.writeLog(u'开始运行Redis行情服务器, {}'.format(last_dt)) + lastResult = dict() + while self.connection_status: + if len(self.registed_symbol_set)==0: + continue + # 从Redis中读取指数数据 + symbols = sorted(list(self.registed_symbol_set)) + results = self.redis.mget(symbols) + + # 只有当Tick发生改变时才推送到Gateway + # 比较2个源的数据, 选择最新&正确的那个 + # 1. 每个源判断接收到的数据是否异常, 异常则抛弃 + # - 波动幅度>10% + # - 价格异常 + # 2. 只有1个源有数据, 就选择那个源 + # 3. 如果2个源都有数据, 就选择时间更新的那个源 + for i in range(0, len(symbols), 2): + # 解码Redis返回值 + rt_tick_dict0 = None + rt_tick_dict1 = None + + if results[i] is not None and len(results[i]) > 0: + try: + rt_tick_dict0 = json.loads(str(results[i], 'utf-8')) + if len(rt_tick_dict0) == 0: + self.writeLog(u'redis[{}]: rt_list0 为空, {}'.format(symbols[i], results[i])) + rt_tick_dict0 = None + elif symbols[i] in lastResult and lastResult[symbols[i]] is not None: + if str(rt_tick_dict0) != str(lastResult[symbols[i]]): + change_percent = abs(float(lastResult[symbols[i]]['LastPrice']) - float(rt_tick_dict0['LastPrice'])) / float(lastResult[symbols[i]]['LastPrice']) * 100 + if change_percent >= 10: + self.writeLog(u'redis[{}]: rt_list0 数据异常, 变动>10%: {}% ({}, {})'.format(symbols[i], + change_percent, lastResult[symbols[i]]['LastPrice'], rt_tick_dict0['LastPrice'])) + rt_tick_dict0 = None + else: + # 无变动, 不需要更新 + rt_tick_dict0 = None + else: + lastResult[symbols[i]] = rt_tick_dict0 + except Exception as ex: + self.writeError(u'Redis行情服务器 run() exception:{},{}, rt_list0:{}不正确'.format(str(ex), + traceback.format_exc(),results[i])) + if results[i+1] is not None and len(results[i+1]) > 0: + try: + rt_tick_dict1 = json.loads(str(results[i + 1], 'utf-8')) + if len(rt_tick_dict1) == 0: + self.writeLog(u'redis[{}]: rt_list1 为空, {}'.format(symbols[i+1], results[i+1])) + rt_tick_dict1 = None + + elif symbols[i+1] in lastResult and lastResult[symbols[i+1]] is not None: + if str(rt_tick_dict1) != str(lastResult[symbols[i+1]]): + change_percent = abs(float(lastResult[symbols[i+1]]['LastPrice']) - float(rt_tick_dict1['LastPrice'])) / float(lastResult[symbols[i+1]]['LastPrice']) * 100 + if change_percent >= 10: + self.writeLog(u'redis[{}]: rt_list1 数据异常, 变动>10%: {}% ({}, {})'.format(symbols[i+1], + change_percent, lastResult[symbols[i+1]]['LastPrice'], rt_tick_dict1['LastPrice'])) + rt_tick_dict1 = None + else: + # 无变动, 不需要更新 + rt_tick_dict1 = None + else: + lastResult[symbols[i + 1]] = rt_tick_dict1 + + except Exception as ex: + self.writeError(u'Redis行情服务器 run() exception:{},{}, rt_list1:{}不正确' + .format(str(ex), traceback.format_exc(),results[i+1])) + + # 选择非空&时间较新的数据 + if rt_tick_dict0 is None and rt_tick_dict1 is None: + continue + + rt_tick_dict = None + if rt_tick_dict1 is None: + rt_tick_dict = rt_tick_dict0 + elif rt_tick_dict0 is None: + rt_tick_dict = rt_tick_dict1 + else: + millionseconds = str(rt_tick_dict0.get('UpdateMillisec',EMPTY_STRING)) + if len(millionseconds) > 6: + millionseconds = millionseconds[0:6] + tick_time = '{}.{}'.format(rt_tick_dict0['UpdateTime'], millionseconds) + rt_tick0_dt = datetime.strptime(rt_tick_dict0['TradingDay'] + ' ' + tick_time, '%Y%m%d %H:%M:%S.%f') + + millionseconds = str(rt_tick_dict1.get('UpdateMillisec',EMPTY_STRING)) + if len(millionseconds) > 6: + millionseconds = millionseconds[0:6] + tick_time = '{}.{}'.format(rt_tick_dict1['UpdateTime'], millionseconds) + rt_tick1_dt = datetime.strptime(rt_tick_dict1['TradingDay'] + ' ' + tick_time, '%Y%m%d %H:%M:%S.%f') + + if rt_tick0_dt < rt_tick1_dt: + rt_tick_dict = rt_tick_dict1 + else: + rt_tick_dict = rt_tick_dict0 + + self.processReq(symbols[i][0:-2], rt_tick_dict) + + # 等待下次查询 (500ms) + # self.writeLog(u'redis[{}] sleep'.format(i)) + sleep(self.req_interval) + dt = datetime.now() + if last_dt.minute != dt.minute: + self.writeLog('Redis行情服务器 check point. {}, process symbols:{}'.format(dt, symbols)) + last_dt = dt + except Exception as ex: + self.writeError(u'Redis行情服务器 run() exception:{},{}'.format(str(ex), traceback.format_exc())) + + self.writeError(u'Redis行情服务器在{}退出'.format(datetime.now())) + + def processReq(self, symbol, tick_dict): + """ + 处理行情信息ticker请求 + :param symbol: + :param req: + :return: + """ + if not isinstance(tick_dict,dict): + self.writeLog(u'行情tick不是dict:{}'.format(tick_dict)) + return + + self.last_tick_dt = datetime.now() + + # 忽略成交量为0的无效单合约tick数据 + if int(tick_dict.get('Volume', 0)) <= 0: + self.writeLog(u'Redis服务器{}:{} 忽略成交量为0的无效合约tick数据: {}' + .format(self.best_ip['ip'], self.best_ip['port'], tick_dict)) + + tick = VtTickData() + tick.gatewayName = self.gatewayName + + tick.symbol = symbol + if tick.symbol is None: + return + tick.exchange = tick_dict.get('ExchangeID') + tick.vtSymbol = tick.symbol + + short_symbol = tick.vtSymbol + short_symbol = short_symbol.replace('99', '').upper() + + # 使用本地时间 + tick.datetime = datetime.now() + # 修正毫秒 + last_tick = self.symbol_tick_dict.get(symbol, None) + if (last_tick is not None) and tick.datetime.replace(microsecond=0) == last_tick.datetime: + # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 + tick.datetime = tick.datetime.replace(microsecond=500) + tick.time = tick.datetime.strftime('%H:%M:%S.%f')[0:12] + else: + tick.datetime = tick.datetime.replace(microsecond=0) + tick.time = tick.datetime.strftime('%H:%M:%S.%f')[0:12] + + tick.date = tick.datetime.strftime('%Y-%m-%d') + + # 生成TradingDay + # 正常日盘时间 + tick.tradingDay = tick.date + + # 修正夜盘的tradingDay + if tick.datetime.hour >= 20: + # 周一~周四晚上20点之后的tick,交易日属于第二天 + if tick.datetime.isoweekday() in [1, 2, 3, 4]: + trading_day = tick.datetime + timedelta(days=1) + tick.tradingDay = trading_day.strftime('%Y-%m-%d') + # 周五晚上20点之后的tick,交易日属于下周一 + elif tick.datetime.isoweekday() == 5: + trading_day = tick.datetime + timedelta(days=3) + tick.tradingDay = trading_day.strftime('%Y-%m-%d') + elif tick.datetime.hour < 3: + # 周六凌晨的tick,交易日属于下周一 + if tick.datetime.isoweekday() == 6: + trading_day = tick.datetime + timedelta(days=2) + tick.tradingDay = trading_day.strftime('%Y-%m-%d') + + # 排除非交易时间得tick + if tick.exchange is EXCHANGE_CFFEX: + if tick.datetime.hour not in [9, 10, 11, 13, 14, 15]: + return + if tick.datetime.hour == 9 and tick.datetime.minute < 15: + return + if tick.datetime.hour == 15 and tick.datetime.minute >= 15: + return + else: # 大商所/郑商所,上期所,上海能源 + # 排除非开盘小时 + if tick.datetime.hour in [3, 4, 5, 6, 7, 8, 12, 15, 16, 17, 18, 19, 20]: + return + # 排除早盘 10:15~10:30 + if tick.datetime.hour == 10 and 15 <= tick.datetime.minute < 30: + return + # 排除早盘 11:30~12:00 + if tick.datetime.hour == 11 and tick.datetime.minute >= 30: + return + # 排除午盘 13:00 ~13:30 + if tick.datetime.hour == 13 and tick.datetime.minute < 30: + return + # 排除凌晨2:30~3:00 + if tick.datetime.hour == 2 and tick.datetime.minute >= 30: + return + + # 排除大商所/郑商所夜盘数据 + if short_symbol in NIGHT_MARKET_ZZ: + if tick.datetime.hour == 23 and tick.datetime.minute >= 30: + return + if tick.datetime.hour in [0, 1, 2]: + return + + # 排除上期所夜盘数据 23:00 收盘 + if short_symbol in NIGHT_MARKET_DL or short_symbol in NIGHT_MARKET_SQ3: + if tick.datetime.hour in [23, 0, 1, 2]: + return + # 排除上期所夜盘数据 1:00 收盘 + if short_symbol in NIGHT_MARKET_SQ2: + if tick.datetime.hour in [1, 2]: + return + + # 排除日盘合约在夜盘得数据 + if short_symbol in MARKET_DAY_ONLY and (tick.datetime.hour < 9 or tick.datetime.hour > 16): + self.writeLog(u'Redis服务器{}:{} 排除日盘合约{}在夜盘得数据, {}'.format(self.best_ip['ip'], + self.best_ip['port'], short_symbol, tick_dict)) + return + + # 设置指数价格 + tick.preClosePrice = float(tick_dict.get('PreClosePrice', 0)) + tick.highPrice = float(tick_dict.get('HighestPrice', 0)) + tick.openPrice = float(tick_dict.get('OpenPrice', 0)) + tick.lowPrice = float(tick_dict.get('LowestPrice', 0)) + tick.lastPrice = float(tick_dict.get('LastPrice', 0)) + + tick.volume = int(tick_dict.get('Volume', 0)) + tick.openInterest = int(tick_dict.get('OpenInterest', 0)) + tick.lastVolume = 0 # 最新成交量 + tick.upperLimit = float(tick_dict.get('UpperLimitPrice', 0)) + tick.lowerLimit = float(tick_dict.get('LowerLimitPrice', 0)) + + # CTP只有一档行情 + if tick_dict.get('BidPrice1', 'nan') == 'nan': # 上期所有时会返回nan + tick.bidPrice1 = tick.lastPrice + else: + tick.bidPrice1 = float(tick_dict.get('BidPrice1', 0)) + tick.bidVolume1 = int(tick_dict.get('BidVolume1', 0)) + if tick_dict.get('AskPrice1', 'nan') == 'nan': # 上期所有时会返回nan + tick.askPrice1 = tick.lastPrice + else: + tick.askPrice1 = float(tick_dict.get('AskPrice1', 0)) + tick.askVolume1 = int(tick_dict.get('AskVolume1', 0)) + + tick.preOpenInterest = int(tick_dict.get('PreOpenInterest', 0)) # 昨持仓量 + + if self.gateway.debug_tick: + self.writeLog('Redis服务器{}[{}]:\n gateway:{},symbol:{},exch:{},vtsymbol:{},\n dt:{},td:{},' + 'lastPrice:{},volume:{},' + 'openPirce:{},highprice:{},lowPrice:{},preClosePrice:{},\nbid1:{},bv1:{},ask1:{},av1:{}' + .format(self.best_ip['ip'], + self.best_ip['port'], + tick.gatewayName, tick.symbol, + tick.exchange, tick.vtSymbol, + tick.datetime, tick.tradingDay, + tick.lastPrice,tick.volume, + tick.openPrice, tick.highPrice, + tick.lowPrice, tick.preClosePrice, + tick.bidPrice1, + tick.bidVolume1, tick.askPrice1, + tick.askVolume1)) + + self.symbol_tick_dict[symbol] = tick + self.gateway.onTick(tick) + self.gateway.onCustomerTick(tick) + + # ---------------------------------------------------------------------- + def writeLog(self, content): + """发出日志""" + log = VtLogData() + log.gatewayName = self.gatewayName + log.logContent = content + self.gateway.onLog(log) + + def writeError(self, content): + self.writeLog(content) + +class TickCombiner(object): + """ + Tick合成类 + """ + def __init__(self, gateway, setting): + self.gateway = gateway + + self.gateway.writeLog(u'创建tick合成类:{}'.format(setting)) + + self.vtSymbol = setting.get('vtSymbol',None) + self.leg1_symbol = setting.get('leg1_symbol',None) + self.leg2_symbol = setting.get('leg2_symbol',None) + self.leg1_ratio = setting.get('leg1_ratio', 1) # 腿1的数量配比 + self.leg2_ratio = setting.get('leg2_ratio', 1) # 腿2的数量配比 + self.minDiff = setting.get('minDiff',1) # 合成价差加比后的最小跳动 + # 价差 + self.is_spread = setting.get('is_spread', False) + # 价比 + self.is_ratio = setting.get('is_ratio', False) + + self.last_leg1_tick = None + self.last_leg2_tick = None + + # 价差日内最高/最低价 + self.spread_high = None + self.spread_low = None + + # 价比日内最高/最低价 + self.ratio_high = None + self.ratio_low = None + + # 当前交易日 + self.tradingDay = None + + if self.is_ratio and self.is_spread: + self.gateway.writeError(u'{}参数有误,不能同时做价差/加比.setting:{}'.format(self.vtSymbol,setting)) + return + + self.gateway.writeLog(u'初始化{}合成器成功'.format(self.vtSymbol)) + if self.is_spread: + self.gateway.writeLog(u'leg1:{} * {} - leg2:{} * {}'.format(self.leg1_symbol,self.leg1_ratio,self.leg2_symbol,self.leg2_ratio)) + if self.is_ratio: + self.gateway.writeLog(u'leg1:{} * {} / leg2:{} * {}'.format(self.leg1_symbol, self.leg1_ratio, self.leg2_symbol, + self.leg2_ratio)) + def onTick(self, tick): + """OnTick处理""" + combinable = False + + if tick.vtSymbol == self.leg1_symbol: + # leg1合约 + self.last_leg1_tick = tick + if self.last_leg2_tick is not None: + if self.last_leg1_tick.datetime.replace(microsecond=0) == self.last_leg2_tick.datetime.replace(microsecond=0): + combinable = True + + elif tick.vtSymbol == self.leg2_symbol: + # leg2合约 + self.last_leg2_tick = tick + if self.last_leg1_tick is not None: + if self.last_leg2_tick.datetime.replace(microsecond=0) == self.last_leg1_tick.datetime.replace(microsecond=0): + combinable = True + + # 不能合并 + if not combinable: + return + + if not self.is_ratio and not self.is_spread: + return + + # 以下情况,基本为单腿涨跌停,不合成价差/价格比 Tick + if (self.last_leg1_tick.askPrice1 == 0 or self.last_leg1_tick.bidPrice1 == self.last_leg1_tick.upperLimit)\ + and self.last_leg1_tick.askVolume1 == 0: + self.gateway.writeLog(u'leg1:{0}涨停{1},不合成价差Tick'.format(self.last_leg1_tick.vtSymbol,self.last_leg1_tick.bidPrice1)) + return + if (self.last_leg1_tick.bidPrice1 == 0 or self.last_leg1_tick.askPrice1 == self.last_leg1_tick.lowerLimit)\ + and self.last_leg1_tick.bidVolume1 == 0: + self.gateway.writeLog(u'leg1:{0}跌停{1},不合成价差Tick'.format(self.last_leg1_tick.vtSymbol, self.last_leg1_tick.askPrice1)) + return + if (self.last_leg2_tick.askPrice1 == 0 or self.last_leg2_tick.bidPrice1 == self.last_leg2_tick.upperLimit)\ + and self.last_leg2_tick.askVolume1 == 0: + self.gateway.writeLog(u'leg2:{0}涨停{1},不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.bidPrice1)) + return + if (self.last_leg2_tick.bidPrice1 == 0 or self.last_leg2_tick.askPrice1 == self.last_leg2_tick.lowerLimit)\ + and self.last_leg2_tick.bidVolume1 == 0: + self.gateway.writeLog(u'leg2:{0}跌停{1},不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.askPrice1)) + return + + if self.tradingDay != tick.tradingDay: + self.tradingDay = tick.tradingDay + self.spread_high = None + self.spread_low = None + self.ratio_high = None + self.ratio_low = None + + if self.is_spread: + spread_tick = VtTickData() + spread_tick.vtSymbol = self.vtSymbol + spread_tick.symbol = self.vtSymbol + spread_tick.gatewayName = tick.gatewayName + spread_tick.exchange = tick.exchange + spread_tick.tradingDay = tick.tradingDay + + spread_tick.datetime = tick.datetime + spread_tick.date = tick.date + spread_tick.time = tick.time + + # 叫卖价差=leg1.askPrice1 * 配比 - leg2.bidPrice1 * 配比,volume为两者最小 + spread_tick.askPrice1 = roundToPriceTick(priceTick=self.minDiff, + price =self.last_leg1_tick.askPrice1 * self.leg1_ratio - self.last_leg2_tick.bidPrice1 * self.leg2_ratio) + spread_tick.askVolume1 = min(self.last_leg1_tick.askVolume1, self.last_leg2_tick.bidVolume1) + + # 叫买价差=leg1.bidPrice1 * 配比 - leg2.askPrice1 * 配比,volume为两者最小 + spread_tick.bidPrice1 = roundToPriceTick(priceTick=self.minDiff, + price=self.last_leg1_tick.bidPrice1 * self.leg1_ratio - self.last_leg2_tick.askPrice1 * self.leg2_ratio) + spread_tick.bidVolume1 = min(self.last_leg1_tick.bidVolume1, self.last_leg2_tick.askVolume1) + + # 最新价 + spread_tick.lastPrice = roundToPriceTick(priceTick=self.minDiff, + price=(spread_tick.askPrice1 + spread_tick.bidPrice1)/2) + # 昨收盘价 + if self.last_leg2_tick.preClosePrice >0 and self.last_leg1_tick.preClosePrice > 0: + spread_tick.preClosePrice = roundToPriceTick(priceTick=self.minDiff, + price=self.last_leg1_tick.preClosePrice * self.leg1_ratio - self.last_leg2_tick.preClosePrice * self.leg2_ratio) + # 开盘价 + if self.last_leg2_tick.openPrice > 0 and self.last_leg1_tick.openPrice > 0: + spread_tick.openPrice = roundToPriceTick(priceTick=self.minDiff, + price=self.last_leg1_tick.openPrice * self.leg1_ratio - self.last_leg2_tick.openPrice * self.leg2_ratio) + # 最高价 + self.spread_high = spread_tick.askPrice1 if self.spread_high is None else max(self.spread_high,spread_tick.askPrice1) + spread_tick.highPrice = self.spread_high + + # 最低价 + self.spread_low = spread_tick.bidPrice1 if self.spread_low is None else min(self.spread_low, spread_tick.bidPrice1) + spread_tick.lowPrice = self.spread_low + + self.gateway.onTick(spread_tick) + + if self.is_ratio: + ratio_tick = VtTickData() + ratio_tick.vtSymbol = self.vtSymbol + ratio_tick.symbol = self.vtSymbol + ratio_tick.gatewayName = tick.gatewayName + ratio_tick.exchange = tick.exchange + ratio_tick.tradingDay = tick.tradingDay + ratio_tick.datetime = tick.datetime + ratio_tick.date = tick.date + ratio_tick.time = tick.time + + # 比率tick + ratio_tick.askPrice1 = roundToPriceTick(priceTick=self.minDiff, + price=100 * self.last_leg1_tick.askPrice1 * self.leg1_ratio / (self.last_leg2_tick.bidPrice1 * self.leg2_ratio)) + ratio_tick.askVolume1 = min(self.last_leg1_tick.askVolume1, self.last_leg2_tick.bidVolume1) + + ratio_tick.bidPrice1 = roundToPriceTick(priceTick=self.minDiff, + price=100 * self.last_leg1_tick.bidPrice1 * self.leg1_ratio/ (self.last_leg2_tick.askPrice1 * self.leg2_ratio)) + ratio_tick.bidVolume1 = min(self.last_leg1_tick.bidVolume1, self.last_leg2_tick.askVolume1) + ratio_tick.lastPrice = roundToPriceTick(priceTick=self.minDiff,price=(ratio_tick.askPrice1 + ratio_tick.bidPrice1) / 2) + + # 昨收盘价 + if self.last_leg2_tick.preClosePrice >0 and self.last_leg1_tick.preClosePrice > 0: + ratio_tick.preClosePrice = roundToPriceTick(priceTick=self.minDiff, + price=100*self.last_leg1_tick.preClosePrice * self.leg1_ratio/(self.last_leg2_tick.preClosePrice * self.leg2_ratio)) + # 开盘价 + if self.last_leg2_tick.openPrice > 0 and self.last_leg1_tick.openPrice > 0: + ratio_tick.openPrice = roundToPriceTick(priceTick=self.minDiff, + price=100 * self.last_leg1_tick.openPrice * self.leg1_ratio / (self.last_leg2_tick.openPrice * self.leg2_ratio)) + # 最高价 + self.ratio_high = ratio_tick.askPrice1 if self.ratio_high is None else max(self.ratio_high, + ratio_tick.askPrice1) + ratio_tick.highPrice = self.spread_high + + # 最低价 + self.ratio_low = ratio_tick.bidPrice1 if self.ratio_low is None else min(self.ratio_low, + ratio_tick.bidPrice1) + ratio_tick.lowPrice = self.spread_low + + self.gateway.onTick(ratio_tick) + #---------------------------------------------------------------------- def test(): """测试""" diff --git a/vnpy/trader/gateway/ctpseGateway/ctpseGateway.py b/vnpy/trader/gateway/ctpseGateway/ctpseGateway.py index ecc9e759..bd2aa40b 100644 --- a/vnpy/trader/gateway/ctpseGateway/ctpseGateway.py +++ b/vnpy/trader/gateway/ctpseGateway/ctpseGateway.py @@ -167,6 +167,7 @@ class CtpseGateway(VtGateway): mdAddress = str(setting['mdAddress']) self.debug_tick = setting.get('debug_tick',False) + self.debug = setting.get('debug',False) # 如果json文件提供了验证码 if 'authCode' in setting: @@ -494,8 +495,7 @@ class CtpMdApi(MdApi): def onHeartBeatWarning(self, n): """心跳报警""" # 因为API的心跳报警比较常被触发,且与API工作关系不大,因此选择忽略 - if getattr(self.gateway,'debug',False): - print('onHeartBeatWarning') + pass #---------------------------------------------------------------------- def onRspError(self, error, n, last): @@ -554,15 +554,12 @@ class CtpMdApi(MdApi): def onRspSubMarketData(self, data, error, n, last): """订阅合约回报""" # 通常不在乎订阅错误,选择忽略 - if getattr(self.gateway, 'debug', False): - print('onRspSubMarketData') #---------------------------------------------------------------------- def onRspUnSubMarketData(self, data, error, n, last): """退订合约回报""" # 同上 - if getattr(self.gateway, 'debug', False): - print('onRspUnSubMarketData') + pass #---------------------------------------------------------------------- def onRtnDepthMarketData(self, data): @@ -572,8 +569,6 @@ class CtpMdApi(MdApi): # self.writeLog(u'忽略成交量为0的无效单合约tick数据:') # self.writeLog(data) # return - if getattr(self.gateway, 'debug', False): - print('onRtnDepthMarketData') if not self.connectionStatus: self.connectionStatus = True @@ -683,29 +678,25 @@ class CtpMdApi(MdApi): self.gateway.onCustomerTick(tick) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def onRspSubForQuoteRsp(self, data, error, n, last): """订阅期权询价""" - if getattr(self.gateway, 'debug', False): - print('onRspSubForQuoteRsp') - - #---------------------------------------------------------------------- - def onRspUnSubForQuoteRsp(self, data, error, n, last): - """退订期权询价""" - if getattr(self.gateway, 'debug', False): - print('onRspUnSubForQuoteRsp') - - #---------------------------------------------------------------------- - def onRtnForQuoteRsp(self, data): - """期权询价推送""" - if getattr(self.gateway, 'debug', False): - print('onRtnForQuoteRsp') + pass #---------------------------------------------------------------------- + def onRspUnSubForQuoteRsp(self, data, error, n, last): + """退订期权询价""" + pass + + # ---------------------------------------------------------------------- + def onRtnForQuoteRsp(self, data): + """期权询价推送""" + pass + + # ---------------------------------------------------------------------- def connect(self, userID, password, brokerID, address): """初始化连接""" - if getattr(self.gateway, 'debug', False): - print('connect') + self.userID = userID # 账号 self.password = password # 密码 self.brokerID = brokerID # 经纪商代码 @@ -718,7 +709,7 @@ class CtpMdApi(MdApi): if not os.path.exists(path): os.makedirs(path) self.createFtdcMdApi(path) - + self.writeLog(u'注册行情服务器地址:{}'.format(self.address)) # 注册服务器地址 self.registerFront(self.address) @@ -732,37 +723,43 @@ class CtpMdApi(MdApi): if not self.loginStatus: self.login() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def subscribe(self, subscribeReq): """订阅合约""" # 这里的设计是,如果尚未登录就调用了订阅方法 # 则先保存订阅请求,登录完成后会自动订阅 - if getattr(self.gateway, 'debug', False): - print('subscribe') + + + if self.connectionStatus and not self.loginStatus: + self.login() # 订阅传统合约 self.subscribeMarketData(str(subscribeReq.symbol)) self.writeLog(u'订阅合约:{0}'.format(str(subscribeReq.symbol))) self.subscribedSymbols.add(subscribeReq) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def login(self): """登录""" + # 如果填入了用户名密码等,则登录 if self.userID and self.password and self.brokerID: + self.writeLog(u'登入行情服务器') req = {} req['UserID'] = self.userID req['Password'] = self.password req['BrokerID'] = self.brokerID self.reqID += 1 - self.reqUserLogin(req, self.reqID) + self.reqUserLogin(req, self.reqID) + else: + self.writeLog(u'未配置用户/密码,不登录行情服务器') - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def close(self): """关闭""" self.exit() - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def writeLog(self, content): """发出日志""" log = VtLogData() @@ -809,8 +806,6 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onFrontConnected(self): """服务器连接""" - if getattr(self.gateway, 'debug', False): - print('onFrontConnected') self.connectionStatus = True self.writeLog(text.TRADING_SERVER_CONNECTED) @@ -822,8 +817,6 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onFrontDisconnected(self, n): """服务器断开""" - if getattr(self.gateway, 'debug', False): - print('onFrontDisconnected') self.connectionStatus = False self.loginStatus = False self.gateway.tdConnected = False @@ -833,15 +826,10 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onHeartBeatWarning(self, n): """""" - if getattr(self.gateway, 'debug', False): - print('onHeartBeatWarning') #---------------------------------------------------------------------- def onRspAuthenticate(self, data, error, n, last): """验证客户端回报""" - if getattr(self.gateway, 'debug', False): - print('onRspAuthenticate') - if error['ErrorID'] == 0: self.authStatus = True self.writeLog(text.TRADING_SERVER_AUTHENTICATED) @@ -852,8 +840,6 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspUserLogin(self, data, error, n, last): """登陆回报""" - if getattr(self.gateway, 'debug', False): - print('onRspUserLogin') # 如果登录成功,推送日志信息 if error['ErrorID'] == 0: self.tradingDay = str(data['TradingDay']) @@ -887,16 +873,12 @@ class CtpTdApi(TdApi): def resentReqQryInstrument(self): # 查询合约代码 - if getattr(self.gateway, 'debug', False): - print('resentReqQryInstrument') self.reqID += 1 self.reqQryInstrument({}, self.reqID) #---------------------------------------------------------------------- def onRspUserLogout(self, data, error, n, last): """登出回报""" - if getattr(self.gateway, 'debug', False): - print('onRspUserLogout') # 如果登出成功,推送日志信息 if error['ErrorID'] == 0: self.loginStatus = False @@ -915,20 +897,16 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspUserPasswordUpdate(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspUserPasswordUpdate') + pass #---------------------------------------------------------------------- def onRspTradingAccountPasswordUpdate(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspTradingAccountPasswordUpdate') + pass #---------------------------------------------------------------------- def onRspOrderInsert(self, data, error, n, last): """发单错误(柜台)""" - if getattr(self.gateway, 'debug', False): - print('onRspOrderInsert') # 推送委托信息 order = VtOrderData() order.gatewayName = self.gatewayName @@ -956,20 +934,16 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspParkedOrderInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspParkedOrderInsert') + pass #---------------------------------------------------------------------- def onRspParkedOrderAction(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspParkedOrderAction') + pass #---------------------------------------------------------------------- def onRspOrderAction(self, data, error, n, last): """撤单错误(柜台)""" - if getattr(self.gateway, 'debug', False): - print('onRspOrderAction') try: symbol = data['InstrumentID'] except KeyError: @@ -985,89 +959,77 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspQueryMaxOrderVolume(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQueryMaxOrderVolume') + pass #---------------------------------------------------------------------- def onRspSettlementInfoConfirm(self, data, error, n, last): """确认结算信息回报""" - if getattr(self.gateway, 'debug', False): - print('onRspSettlementInfoConfirm') self.writeLog(text.SETTLEMENT_INFO_CONFIRMED) - #---------------------------------------------------------------------- def onRspRemoveParkedOrder(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspSettlementInfoConfirm') + pass #---------------------------------------------------------------------- def onRspRemoveParkedOrderAction(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspRemoveParkedOrderAction') + pass #---------------------------------------------------------------------- def onRspExecOrderInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspExecOrderInsert') + pass #---------------------------------------------------------------------- def onRspExecOrderAction(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspExecOrderAction') + pass #---------------------------------------------------------------------- def onRspForQuoteInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspForQuoteInsert') + pass #---------------------------------------------------------------------- def onRspQuoteInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQuoteInsert') + pass #---------------------------------------------------------------------- def onRspQuoteAction(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQuoteAction') + pass #---------------------------------------------------------------------- def onRspLockInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspLockInsert') + pass #---------------------------------------------------------------------- def onRspCombActionInsert(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspCombActionInsert') + pass #---------------------------------------------------------------------- def onRspQryOrder(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspCombActionInsert') + pass #---------------------------------------------------------------------- def onRspQryTrade(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryTrade') + pass #---------------------------------------------------------------------- def onRspQryInvestorPosition(self, data, error, n, last): """持仓查询回报""" - if getattr(self.gateway, 'debug', False): + if self.gateway.debug: print('onRspQryInvestorPosition') + print(u'data:{}'.format(data)) + print(u'error:{}'.format(error)) + print('n:{},last:{}'.format(n,last)) if not data['InstrumentID']: return @@ -1075,64 +1037,82 @@ class CtpTdApi(TdApi): if not self.gateway.tdConnected: self.gateway.tdConnected = True - # 获取持仓缓存对象 - posName = '.'.join([data['InstrumentID'], data['PosiDirection']]) - if posName in self.posDict: - pos = self.posDict[posName] - else: - pos = VtPositionData() - self.posDict[posName] = pos - - pos.gatewayName = self.gatewayName - pos.symbol = data['InstrumentID'] - pos.vtSymbol = pos.symbol - pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') - pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName]) + try: + # 获取持仓缓存对象 + posName = '.'.join([data['InstrumentID'], data['PosiDirection']]) + if posName in self.posDict: + pos = self.posDict[posName] + else: + pos = VtPositionData() + self.posDict[posName] = pos - exchange = self.symbolExchangeDict.get(pos.symbol, EXCHANGE_UNKNOWN) + pos.gatewayName = self.gatewayName + pos.symbol = data['InstrumentID'] + pos.vtSymbol = pos.symbol + pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') + pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName]) - # 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据 - if exchange == EXCHANGE_SHFE: - if data['YdPosition'] and not data['TodayPosition']: - pos.ydPosition = data['Position'] - # 否则基于总持仓和今持仓来计算昨仓数据 - else: - pos.ydPosition = data['Position'] - data['TodayPosition'] + exchange = self.symbolExchangeDict.get(pos.symbol, EXCHANGE_UNKNOWN) - # 计算成本 - if pos.symbol not in self.symbolSizeDict: - return - size = self.symbolSizeDict[pos.symbol] - cost = pos.price * pos.position * size + # 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据 + if exchange == EXCHANGE_SHFE: + if data['YdPosition'] and not data['TodayPosition']: + pos.ydPosition = data['Position'] + # 否则基于总持仓和今持仓来计算昨仓数据 + else: + pos.ydPosition = data['Position'] - data['TodayPosition'] - # 汇总总仓 - pos.position += data['Position'] - pos.positionProfit += data['PositionProfit'] + # 计算成本 + if pos.symbol not in self.symbolSizeDict: + return + size = self.symbolSizeDict[pos.symbol] + cost = pos.price * pos.position * size - # 计算持仓均价 - if pos.position and size: - pos.price = (cost + data['PositionCost']) / (pos.position * size) + # 汇总总仓 + pos.position += data['Position'] - # 读取冻结 - if pos.direction is DIRECTION_LONG: - pos.frozen += data['LongFrozen'] - else: - pos.frozen += data['ShortFrozen'] + # 计算持仓均价 + if pos.position and size: + #pos.price = (cost + data['PositionCost']) / (pos.position * size) + pos.price = (cost + data['OpenCost']) / (pos.position * size) + + # 上一交易日结算价 + pre_settlement_price = data['PreSettlementPrice'] + # 开仓均价 + open_cost_price = data['OpenCost'] / (pos.position * size) + # 逐笔盈亏 = (上一交易日结算价 - 开仓价)* 持仓数量 * 杠杆 + 当日持仓收益 + if pos.direction == DIRECTION_LONG: + pre_profit = (pre_settlement_price - open_cost_price) * (pos.position * size) + else: + pre_profit = (open_cost_price - pre_settlement_price) * (pos.position * size) + + pos.positionProfit = pos.positionProfit + pre_profit + data['PositionProfit'] + + # 读取冻结 + if pos.direction is DIRECTION_LONG: + pos.frozen += data['LongFrozen'] + else: + pos.frozen += data['ShortFrozen'] + + # 查询回报结束 + if last: + # 遍历推送 + if self.gateway.debug: + print(u'最后推送') + for pos in list(self.posDict.values()): + self.gateway.onPosition(pos) + + # 清空缓存 + self.posDict.clear() + except Exception as ex: + self.gateway.writeError('onRspQryInvestorPosition exception:{}'.format(str(ex))) + self.gateway.writeError('traceL{}'.format(traceback.format_exc())) - # 查询回报结束 - if last: - # 遍历推送 - for pos in list(self.posDict.values()): - self.gateway.onPosition(pos) - - # 清空缓存 - self.posDict.clear() #---------------------------------------------------------------------- def onRspQryTradingAccount(self, data, error, n, last): """资金账户查询回报""" - if getattr(self.gateway, 'debug', False): - print('onRspQryTradingAccount') + self.gateway.mdConnected = True account = VtAccountData() @@ -1162,14 +1142,12 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspQryInvestor(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInvestor') + pass #---------------------------------------------------------------------- def onRspQryTradingCode(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryTradingCode') + pass #---------------------------------------------------------------------- def onRspQryInstrumentMarginRate(self, data, error, n, last): @@ -1181,32 +1159,26 @@ class CtpTdApi(TdApi): :param last: :return: """ - if getattr(self.gateway, 'debug', False): - print('onRspQryInstrumentMarginRate') + pass #---------------------------------------------------------------------- def onRspQryInstrumentCommissionRate(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInstrumentCommissionRate') + pass #---------------------------------------------------------------------- def onRspQryExchange(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryExchange') + pass #---------------------------------------------------------------------- def onRspQryProduct(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryProduct') + pass #---------------------------------------------------------------------- def onRspQryInstrument(self, data, error, n, last): """合约查询回报""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInstrument') self.gateway.mdConnected = True contract = VtContractData() @@ -1260,74 +1232,61 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspQryDepthMarketData(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryDepthMarketData') + pass #---------------------------------------------------------------------- def onRspQrySettlementInfo(self, data, error, n, last): """查询结算信息回报""" - if getattr(self.gateway, 'debug', False): - print('onRspQryDepthMarketData') - + pass #---------------------------------------------------------------------- def onRspQryTransferBank(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryTransferBank') + pass #---------------------------------------------------------------------- def onRspQryInvestorPositionDetail(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInvestorPositionDetail') + pass #---------------------------------------------------------------------- def onRspQryNotice(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInvestorPositionDetail') + pass #---------------------------------------------------------------------- def onRspQrySettlementInfoConfirm(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQrySettlementInfoConfirm') + pass #---------------------------------------------------------------------- def onRspQryInvestorPositionCombineDetail(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInvestorPositionCombineDetail') + pass #---------------------------------------------------------------------- def onRspQryCFMMCTradingAccountKey(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryCFMMCTradingAccountKey') + pass #---------------------------------------------------------------------- def onRspQryEWarrantOffset(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryEWarrantOffset') + pass #---------------------------------------------------------------------- def onRspQryInvestorProductGroupMargin(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryInvestorProductGroupMargin') + pass #---------------------------------------------------------------------- def onRspQryExchangeMarginRate(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryExchangeMarginRate') + pass #---------------------------------------------------------------------- def onRspQryExchangeMarginRateAdjust(self, data, error, n, last): """""" - if getattr(self.gateway, 'debug', False): - print('onRspQryExchangeMarginRateAdjust') + pass #---------------------------------------------------------------------- def onRspQryExchangeRate(self, data, error, n, last): diff --git a/vnpy/trader/vtEngine.py b/vnpy/trader/vtEngine.py index 5451685f..81ca1c91 100644 --- a/vnpy/trader/vtEngine.py +++ b/vnpy/trader/vtEngine.py @@ -1,6 +1,8 @@ # encoding: UTF-8 -print(u'启动load vtEngine.py') +# 20190318 增加自定义套利合约Customer_Contract + +print('load vtEngine.py') import shelve from collections import OrderedDict @@ -13,14 +15,18 @@ from pymongo.errors import ConnectionFailure,AutoReconnect from vnpy.trader.vtEvent import Event as vn_event from vnpy.trader.language import text -#from vnpy.trader.app.ctaStrategy.ctaEngine import CtaEngine -#from vnpy.trader.app.dataRecorder.drEngine import DrEngine -#from vnpy.trader.app.riskManager.rmEngine import RmEngine -from vnpy.trader.vtFunction import loadMongoSetting, getTempPath + +from vnpy.trader.vtFunction import loadMongoSetting, getTempPath,getFullSymbol,getShortSymbol,getJsonPath from vnpy.trader.vtGateway import * -from vnpy.trader.app import (ctaStrategy,cmaStrategy, riskManager) +from vnpy.trader.app import (ctaStrategy, riskManager) from vnpy.trader.setup_logger import setup_logger +from vnpy.trader.vtGlobal import globalSetting import traceback +from datetime import datetime, timedelta, time, date + +from vnpy.trader.vtConstant import (DIRECTION_LONG, DIRECTION_SHORT, + OFFSET_OPEN, OFFSET_CLOSE, OFFSET_CLOSETODAY, + STATUS_ALLTRADED, STATUS_CANCELLED, STATUS_REJECTED,STATUS_UNKNOWN) import psutil try: @@ -67,6 +73,7 @@ class MainEngine(object): self.ctaEngine = None # CtaEngine(self, self.eventEngine) # cta策略运行模块 self.drEngine = None # DrEngine(self, self.eventEngine) # 数据记录模块 self.rmEngine = None # RmEngine(self, self.eventEngine) # 风险管理模块 + self.algoEngine = None # 算法交易模块 self.cmaEngine = None # 跨市场套利引擎 self.connected_gw_names = [] @@ -76,6 +83,9 @@ class MainEngine(object): self.createLogger() + self.spd_id = 1 # 自定义套利的委托编号 + self.algo_order_dict = {} # 记录算法交易的委托编号,便于通过算法引擎撤单 + # ---------------------------------------------------------------------- def addGateway(self, gatewayModule,gateway_name=EMPTY_STRING): """添加底层接口""" @@ -119,7 +129,9 @@ class MainEngine(object): self.ctaEngine = self.appDict[appName] elif appName == riskManager.appName: self.rmEngine = self.appDict[appName] - elif appName == cmaStrategy.appName: + elif appName == 'AlgoTrading': + self.algoEngine = self.appDict[appName] + elif appName == 'CrossMarketArbitrage': self.cmaEngine = self.appDict[appName] # 保存应用信息 @@ -198,6 +210,9 @@ class MainEngine(object): def subscribe(self, subscribeReq, gatewayName): """订阅特定接口的行情""" # 处理没有输入gatewayName的情况 + if len(subscribeReq.symbol) == 0: + return + if gatewayName is None or len(gatewayName) == 0: if len(self.connected_gw_names) == 0: self.writeError(u'vtEngine.subscribe, no connected gateway') @@ -214,8 +229,11 @@ class MainEngine(object): self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) # ---------------------------------------------------------------------- - def sendOrder(self, orderReq, gatewayName): - """对特定接口发单""" + def sendOrder(self, orderReq, gatewayName, strategyName=None): + """ + 对特定接口发单 + strategyName: ctaEngine中,发单的策略实例名称 + """ # 如果风控检查失败则不发单 if self.rmEngine and not self.rmEngine.checkRisk(orderReq): self.writeCritical(u'风控检查不通过,gw:{},{} {} {} p:{} v:{}'.format(gatewayName, orderReq.direction, orderReq.offset, orderReq.symbol, orderReq.price, orderReq.volume)) @@ -229,21 +247,145 @@ class MainEngine(object): orderReq.symbol, orderReq.price, orderReq.volume)) return '' + # 判断是否包含gatewayName + if gatewayName is None or len(gatewayName)==0: + if len(self.connected_gw_names) == 1: + gatewayName = self.connected_gw_names[0] + else: + self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) + return '' + else: + if gatewayName not in self.connected_gw_names: + self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) + return '' + + # 如果合约在自定义清单中,并且包含SPD,则使用算法交易, 而不是CtpGateway发单 + if orderReq.symbol.endswith('SPD') and orderReq.symbol in self.dataEngine.custom_contract_setting: + return self.sendAlgoOrder(orderReq, gatewayName, strategyName) + if gatewayName in self.gatewayDict: gateway = self.gatewayDict[gatewayName] return gateway.sendOrder(orderReq) - else: - self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) - + + def sendAlgoOrder(self,orderReq, gatewayName, strategyName=None): + """发送算法交易指令""" + self.writeLog(u'创建算法交易,gatewayName:{},strategyName:{},symbol:{},price:{},volume:{}' + .format(gatewayName, strategyName,orderReq.vtSymbol,orderReq.price,orderReq.volume)) + + # 创建一个Order事件 + order = VtOrderData() + order.vtSymbol = orderReq.vtSymbol + order.symbol = orderReq.symbol + order.exchange = orderReq.exchange + order.gatewayName = gatewayName + order.direction = orderReq.direction + order.offset = orderReq.offset + order.price = orderReq.price + order.totalVolume = orderReq.volume + order.tradedVolume = 0 + order.orderTime = datetime.now().strftime('%H:%M:%S.%f') + order.orderID = 'spd_{}'.format(self.spd_id) + self.spd_id += 1 + order.vtOrderID = gatewayName+'.'+order.orderID + + #如果算法引擎未启动,发出拒单事件 + if self.algoEngine is None: + try: + self.writeLog(u'算法引擎未启动,启动ing') + from vnpy.trader.app import algoTrading + self.addApp(algoTrading) + except Exception as ex: + self.writeError(u'算法引擎未加载,不能创建算法交易') + order.cancelTime = datetime.now().strftime('%H:%M:%S.%f') + order.status = STATUS_REJECTED + event1 = Event(type_=EVENT_ORDER) + event1.dict_['data'] = order + self.eventEngine.put(event1) + return '' + + # 创建算法实例,由算法引擎启动 + tradeCommand = '' + if orderReq.direction == DIRECTION_LONG and orderReq.offset == OFFSET_OPEN: + tradeCommand = 'Buy' + elif orderReq.direction == DIRECTION_SHORT and orderReq.offset == OFFSET_OPEN: + tradeCommand = 'Short' + elif orderReq.direction == DIRECTION_SHORT and orderReq.offset in [OFFSET_CLOSE, OFFSET_CLOSETODAY]: + tradeCommand = 'Sell' + elif orderReq.direction == DIRECTION_LONG and orderReq.offset in [OFFSET_CLOSE, OFFSET_CLOSETODAY]: + tradeCommand = 'Cover' + + contract_setting = self.dataEngine.custom_contract_setting.get(orderReq.symbol,{}) + algo = { + 'templateName': u'SpreadTrading 套利', + 'order_vtSymbol': orderReq.symbol, + 'order_command': tradeCommand, + 'order_price': orderReq.price, + 'order_volume': orderReq.volume, + 'timer_interval': 120, + 'strategy_name': strategyName + } + algo.update(contract_setting) + + # 算法引擎 + algoName = self.algoEngine.addAlgo(algo) + self.writeLog(u'sendAlgoOrder(): addAlgo {}={}'.format(algoName, str(algo))) + + order.status = STATUS_UNKNOWN + order.orderID = algoName + order.vtOrderID = gatewayName + u'.' + algoName + + event1 = Event(type_=EVENT_ORDER) + event1.dict_['data'] = order + self.eventEngine.put(event1) + + # 登记在本地的算法委托字典中 + self.algo_order_dict.update({algoName: {'algo': algo, 'order': order}}) + + return gatewayName + u'.' + algoName + # ---------------------------------------------------------------------- def cancelOrder(self, cancelOrderReq, gatewayName): """对特定接口撤单""" - if gatewayName in self.gatewayDict: + if cancelOrderReq.orderID in self.algo_order_dict: + self.writeLog(u'执行算法实例撤单') + self.cancelAlgoOrder(cancelOrderReq,gatewayName) + return + + if gatewayName in self.gatewayDict and gatewayName in self.connected_gw_names: gateway = self.gatewayDict[gatewayName] gateway.cancelOrder(cancelOrderReq) else: self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName)) + def cancelAlgoOrder(self, cancelOrderReq, gatewayName): + if self.algoEngine is None: + self.writeError(u'算法引擎未实例化,不能撤单:{}'.format(cancelOrderReq.orderID)) + return + try: + d = self.algo_order_dict.get(cancelOrderReq.orderID) + algo = d.get('algo',None) + order = d.get('order', None) + + if algo is None or order is None: + self.writeError(u'未找到算法配置和委托单,不能撤单:{}'.format(cancelOrderReq.orderID)) + return + + if cancelOrderReq.orderID not in self.algoEngine.algoDict: + self.writeError(u'算法实例不存在,不能撤单') + return + + self.algoEngine.stopAlgo(cancelOrderReq.orderID) + + order.cancelTime = datetime.now().strftime('%H:%M:%S.%f') + order.status = STATUS_CANCELLED + event1 = Event(type_=EVENT_ORDER) + event1.dict_['data'] = order + self.eventEngine.put(event1) + + except Exception as ex: + self.writeError(u'算法实例撤销异常:{}\n{}'.format(str(ex),traceback.format_exc())) + + # ---------------------------------------------------------------------- def qryAccount(self, gatewayName): """查询特定接口的账户""" @@ -330,7 +472,7 @@ class MainEngine(object): return True except Exception as ex: - print( u'vtEngine.disconnect Exception:{0} '.format(str(ex))) + self.writeError(u'vtEngine.disconnect Exception:{0} '.format(str(ex))) return False # ---------------------------------------------------------------------- @@ -363,7 +505,7 @@ class MainEngine(object): filename = os.path.abspath(os.path.join(path, 'vnpy')) - print( u'create logger:{}'.format(filename)) + print(u'create logger:{}'.format(filename)) self.logger = setup_logger(filename=filename, name='vnpy', debug=True) # ---------------------------------------------------------------------- @@ -381,15 +523,22 @@ class MainEngine(object): else: self.createLogger() - # 发出邮件/微信 - #try: - # if len(self.gatewayDetailList) > 0: - # target = self.gatewayDetailList[0]['gatewayName'] - # else: - # target = WECHAT_GROUP["DEBUG_01"] - # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_ERROR) - #except Exception as ex: - # print(u'send wechat exception:{}'.format(str(ex)),file=sys.stderr) + # 发出邮件 + if globalSetting.get('activate_email', False): + try: + sendmail(subject=u'{0} Error'.format('_'.join(self.connected_gw_names)), msgcontent=content) + except Exception as ex: + print(u'vtEngine.writeError sendmail Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) + + # 发出微信 + if globalSetting.get('activate_wx_ft', False): + try: + from huafu.util.util_wx_ft import sendWxMsg + sendWxMsg(text=content) + except Exception as ex: + print(u'vtEngine.writeError sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) # ---------------------------------------------------------------------- def writeWarning(self, content): @@ -411,20 +560,22 @@ class MainEngine(object): self.createLogger() # 发出邮件 - try: - sendmail(subject=u'{0} Warning'.format('_'.join(self.connected_gw_names)), msgcontent=content) - except: - pass + if globalSetting.get('activate_email', False): + try: + sendmail(subject=u'{0} Warning'.format('_'.join(self.connected_gw_names)), msgcontent=content) + except Exception as ex: + print(u'vtEngine.writeWarning sendmail Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) # 发出微信 - #try: - # if len(self.gatewayDetailList) > 0: - # target = self.gatewayDetailList[0]['gatewayName'] - # else: - # target = WECHAT_GROUP["DEBUG_01"] - # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_WARNING) - #except Exception as ex: - # print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr) + if globalSetting.get('activate_wx_ft', False): + try: + from huafu.util.util_wx_ft import sendWxMsg + sendWxMsg(text=content) + except Exception as ex: + print(u'vtEngine.writeWarning sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) + # ---------------------------------------------------------------------- def writeNotification(self, content): @@ -436,20 +587,21 @@ class MainEngine(object): self.eventEngine.put(event) # 发出邮件 - try: - sendmail(subject=u'{0} Notification'.format('_'.join(self.connected_gw_names)), msgcontent=content) - except: - pass + if globalSetting.get('activate_email', False): + try: + sendmail(subject=u'{0} Notification'.format('_'.join(self.connected_gw_names)), msgcontent=content) + except Exception as ex: + print(u'vtEngine.writeNotification sendmail Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) # 发出微信 - # try: - # if len(self.gatewayDetailList) > 0: - # target = self.gatewayDetailList[0]['gatewayName'] - # else: - # target = WECHAT_GROUP["DEBUG_01"] - # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_INFO) - # except Exception as ex: - # print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr) + if globalSetting.get('activate_wx_ft', False): + try: + from huafu.util.util_wx_ft import sendWxMsg + sendWxMsg(text=content) + except Exception as ex: + print(u'vtEngine.writeNotification sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) # ---------------------------------------------------------------------- def writeCritical(self, content): @@ -471,23 +623,22 @@ class MainEngine(object): self.createLogger() # 发出邮件 - try: - sendmail(subject=u'{0} Critical'.format('_'.join(self.connected_gw_names)), msgcontent=content) - from vnpy.trader.util_wx_ft import sendWxMsg - sendWxMsg(text=content,desp='Critical error') - except: - pass - - ## 发出微信 - #try: - # # if len(self.gatewayDetailList) > 0: - # target = self.gatewayDetailList[0]['gatewayName'] - # else: - # target = WECHAT_GROUP["DEBUG_01"] - # sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_FATAL) - #except: - # pass + if globalSetting.get('activate_email',False): + # 发出邮件 + try: + sendmail(subject=u'{0} Critical'.format('_'.join(self.connected_gw_names)), msgcontent=content) + except Exception as ex: + print(u'vtEngine.writeCritical sendmail Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) + # 发出微信 + if globalSetting.get('activate_wx_ft',False): + try: + from huafu.util.util_wx_ft import sendWxMsg + sendWxMsg(text=content) + except Exception as ex: + print(u'vtEngine.writeCritical sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr) + print(u'{}'.format(traceback.format_exc()), file=sys.stderr) # # ---------------------------------------------------------------------- def dbConnect(self): @@ -515,7 +666,6 @@ class MainEngine(object): self.writeError(text.DATABASE_CONNECTING_FAILED) self.db_has_connected = False - # ---------------------------------------------------------------------- def dbInsert(self, dbName, collectionName, d): """向MongoDB中插入数据,d是具体数据""" @@ -648,7 +798,7 @@ class MainEngine(object): self.writeError(u'dbQueryBySort exception:{}'.format(str(ex))) return [] - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def dbUpdate(self, dbName, collectionName, d, flt, upsert=False): """向MongoDB中更新数据,d是具体数据,flt是过滤条件,upsert代表若无是否要插入""" try: @@ -701,7 +851,7 @@ class MainEngine(object): except Exception as ex: self.writeError(u'dbDelete exception:{}'.format(str(ex))) - #---------------------------------------------------------------------- + # ---------------------------------------------------------------------- def dbLogging(self, event): """向MongoDB中插入日志""" log = event.dict_['data'] @@ -802,7 +952,13 @@ class DataEngine(object): # 保存合约详细信息的字典 self.contractDict = {} - + + # 本地自定义的合约详细配置字典 + self.custom_contract_setting = {} + + # 合约与自定义套利合约映射表 + self.contract_spd_mapping = {} + # 保存委托数据的字典 self.orderDict = {} @@ -836,10 +992,14 @@ class DataEngine(object): def getContract(self, vtSymbol): """查询合约对象""" try: - return self.contractDict[vtSymbol] - except KeyError: + if vtSymbol in self.contractDict.keys(): + return self.contractDict[vtSymbol] + return None - + except Exception as ex: + print(str(ex),file=sys.stderr) + return None + # ---------------------------------------------------------------------- def getAllContracts(self): """查询所有合约对象(返回列表)""" @@ -863,7 +1023,28 @@ class DataEngine(object): for key, value in d.items(): self.contractDict[key] = value f.close() - + + c = Custom_Contract() + self.custom_contract_setting = c.get_config() + d = c.get_contracts() + if len(d)>0: + print(u'更新本地定制合约') + self.contractDict.update(d) + + # 将leg1,leg2合约对应的自定义spd合约映射 + for spd_name in self.custom_contract_setting.keys(): + setting = self.custom_contract_setting.get(spd_name) + leg1_symbol = setting.get('leg1_symbol') + leg2_symbol = setting.get('leg2_symbol') + + for symbol in [leg1_symbol,leg2_symbol]: + spd_mapping_list = self.contract_spd_mapping.get(symbol,[]) + + # 更新映射 + if spd_name not in spd_mapping_list: + spd_mapping_list.append(spd_name) + self.contract_spd_mapping.update({symbol:spd_mapping_list}) + # ---------------------------------------------------------------------- def updateOrder(self, event): """更新委托数据""" @@ -938,8 +1119,7 @@ class DataEngine(object): def updatePosition(self,event): """更新持仓信息""" - # 在获取更新持仓信息时,自动订阅这个symbol - # 目的:1、 + # 1、在获取更新持仓信息时,自动订阅这个symbol position = event.dict_['data'] symbol = position.symbol @@ -974,3 +1154,59 @@ class DataEngine(object): self.mainEngine.writeLog(u'自动订阅合约{0}'.format(symbol)) +class Custom_Contract(object): + """ + 定制合约 + # 适用于初始化系统时,补充到本地合约信息文件中 contracts.vt + # 适用于CTP网关,加载自定义的套利合约,做内部行情撮合 + """ + # 运行本地目录下,定制合约的配置文件(dict) + file_name = 'Custom_Contracts.json' + custom_config_file = getJsonPath(file_name, __file__) + + def __init__(self): + """构造函数""" + + self.setting = {} # 所有设置 + + try: + # 配置文件是否存在 + if not os.path.exists(self.custom_config_file): + return + + # 加载配置文件,兼容中文说明 + with open(self.custom_config_file,'r',encoding='UTF-8') as f: + # 解析json文件 + print(u'从{}文件加载定制合约'.format(self.custom_config_file)) + self.setting = json.load(f) + + except IOError: + print('读取{} 出错'.format(self.custom_config_file),file=sys.stderr) + + def get_config(self): + """获取配置""" + return self.setting + + def get_contracts(self): + """获取所有合约信息""" + d = {} + for k,v in self.setting.items(): + contract = VtContractData() + contract.gatewayName = v.get('gateway_name',None) + if contract.gatewayName is None and globalSetting.get('gateway_name',None) is not None: + contract.gatewayName = globalSetting.get('gateway_name') + contract.symbol = k + contract.exchange = v.get('exchange',None) + contract.vtSymbol = contract.symbol + contract.name = v.get('name',contract.symbol) + contract.size = v.get('size',100) + + contract.priceTick = v.get('minDiff',0.01) # 最小跳动 + contract.strikePrice = v.get('strike_price',None) + contract.underlyingSymbol = v.get('underlying_symbol') + contract.longMarginRatio = v.get('margin_rate',0.1) + contract.shortMarginRatio = v.get('margin_rate',0.1) + contract.productClass = v.get('product_class',None) + d[contract.vtSymbol] = contract + + return d diff --git a/vnpy/trader/vtFunction.py b/vnpy/trader/vtFunction.py index 2b4d0f93..f1933c5b 100644 --- a/vnpy/trader/vtFunction.py +++ b/vnpy/trader/vtFunction.py @@ -7,14 +7,58 @@ import os,sys import decimal import json -from datetime import datetime +from datetime import datetime,timedelta import importlib import re +from functools import lru_cache MAX_NUMBER = 10000000000000 MAX_DECIMAL = 8 +def import_module_by_str(import_module_name): + """ + 动态导入模块/函数 + :param import_module_name: + :return: + """ + import traceback + from importlib import import_module, reload + + # 参数检查 + if len(import_module_name) == 0: + print('import_module_by_str parameter error,return None') + return None + + print('trying to import {}'.format(import_module_name)) + try: + import_name = str(import_module_name).replace(':', '.') + modules = import_name.split('.') + if len(modules) == 1: + mod = import_module(modules[0]) + return mod + else: + loaded_modules = '.'.join(modules[0:-1]) + print('import {}'.format(loaded_modules)) + mod = import_module(loaded_modules) + + comp = modules[-1] + if not hasattr(mod, comp): + loaded_modules = '.'.join([loaded_modules,comp]) + print('realod {}'.format(loaded_modules)) + mod = reload(loaded_modules) + else: + print('from {} import {}'.format(loaded_modules,comp)) + mod = getattr(mod, comp) + return mod + + except Exception as ex: + print('import {} fail,{},{}'.format(import_module_name,str(ex),traceback.format_exc())) + + return None + + + def floatToStr(float_str): """格式化显示浮点字符串,去除后面的0""" if '.' in float_str: @@ -30,6 +74,45 @@ def floatToStr(float_str): else: return float_str +# ---------------------------------------------------------------------- +def roundToPriceTick(priceTick, price): + """取整价格到合约最小价格变动""" + if not priceTick: + return price + + if price > 0: + # 根据最小跳动取整 + newPrice = price - price % priceTick + else: + # 兼容套利品种的负数价格 + newPrice = round(price / priceTick, 0) * priceTick + + # 数字货币,对浮点的长度有要求,需要砍除多余 + if isinstance(priceTick,float): + price_exponent = decimal.Decimal(str(newPrice)) + tick_exponent = decimal.Decimal(str(priceTick)) + if abs(price_exponent.as_tuple().exponent) > abs(tick_exponent.as_tuple().exponent): + newPrice = round(newPrice, ndigits=abs(tick_exponent.as_tuple().exponent)) + newPrice = float(str(newPrice)) + + return newPrice + +def roundToVolumeTick(volumeTick,volume): + if volumeTick == 0: + return volume + # 取整 + newVolume = volume - volume % volumeTick + + if isinstance(volumeTick,float): + v_exponent = decimal.Decimal(str(newVolume)) + vt_exponent = decimal.Decimal(str(volumeTick)) + if abs(v_exponent.as_tuple().exponent) > abs(vt_exponent.as_tuple().exponent): + newVolume = round(newVolume, ndigits=abs(vt_exponent.as_tuple().exponent)) + newVolume = float(str(newVolume)) + + return newVolume + +@lru_cache() def getShortSymbol(symbol): """取得合约的短号""" # 套利合约 @@ -55,15 +138,25 @@ def getShortSymbol(symbol): return shortSymbol.group(1) +@lru_cache() def getFullSymbol(symbol): - """获取全路径得合约名称""" + """ + 获取全路径得合约名称 + """ + if symbol.endswith('SPD'): + return symbol + short_symbol = getShortSymbol(symbol) if short_symbol == symbol: return symbol symbol_month = symbol.replace(short_symbol, '') if len(symbol_month) == 3: - return '{0}1{1}'.format(short_symbol, symbol_month) + if symbol_month[0] == '0': + # 支持2020年合约 + return '{0}2{1}'.format(short_symbol, symbol_month) + else: + return '{0}1{1}'.format(short_symbol, symbol_month) else: return symbol @@ -116,6 +209,14 @@ def safeUnicode(value): return value + +def get_tdx_market_code(code): + # 获取通达信股票的market code + code = str(code) + if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]: + return 1 + return 0 + #---------------------------------------------------------------------- def loadMongoSetting(): """载入MongoDB数据库的配置""" @@ -137,6 +238,31 @@ def todayDate(): """获取当前本机电脑时间的日期""" return datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) + +def getTradingDate(dt=None): + """ + 根据输入的时间,返回交易日的日期 + :param dt: + :return: + """ + tradingDay = '' + if dt is None: + dt = datetime.now() + + if dt.hour >= 21: + if dt.isoweekday() == 5: + # 星期五=》星期一 + return (dt + timedelta(days=3)).strftime('%Y-%m-%d') + else: + # 第二天 + return (dt + timedelta(days=1)).strftime('%Y-%m-%d') + elif dt.hour < 8 and dt.isoweekday() == 6: + # 星期六=>星期一 + return (dt + timedelta(days=2)).strftime('%Y-%m-%d') + else: + return dt.strftime('%Y-%m-%d') + + # 图标路径 iconPathDict = {} @@ -163,6 +289,12 @@ def getTempPath(name): path = os.path.join(tempPath, name) return path +def get_data_path(): + """获取存放数据文件的路径""" + data_path = os.path.join(os.getcwd(),'data') + if not os.path.exists(data_path): + os.makedirs(data_path) + return data_path # JSON配置文件路径 jsonPathDict = {} @@ -188,14 +320,6 @@ def getJsonPath(name, moduleFile): return moduleJsonPath -# ----------------------------- 扩展的功能 --------- -try: - import openpyxl - from openpyxl.utils.dataframe import dataframe_to_rows - from openpyxl.drawing.image import Image -except: - print(u'can not import openpyxl',file=sys.stderr) - def save_df_to_excel(file_name, sheet_name, df): """ 保存dataframe到execl @@ -207,6 +331,14 @@ def save_df_to_excel(file_name, sheet_name, df): if file_name is None or sheet_name is None or df is None: return False + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + from openpyxl.utils.dataframe import dataframe_to_rows + from openpyxl.drawing.image import Image + except: + print(u'can not import openpyxl', file=sys.stderr) + if 'openpyxl' not in sys.modules: print(u'can not import openpyxl', file=sys.stderr) return False @@ -254,6 +386,14 @@ def save_text_to_excel(file_name, sheet_name, text): if file_name is None or len(sheet_name)==0 or len(text) == 0 : return False + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + from openpyxl.utils.dataframe import dataframe_to_rows + from openpyxl.drawing.image import Image + except: + print(u'can not import openpyxl', file=sys.stderr) + if 'openpyxl' not in sys.modules: return False @@ -300,6 +440,13 @@ def save_images_to_excel(file_name, sheet_name, image_names): """ if file_name is None or len(sheet_name) == 0 or len(image_names) == 0: return False + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + from openpyxl.utils.dataframe import dataframe_to_rows + from openpyxl.drawing.image import Image + except: + print(u'can not import openpyxl', file=sys.stderr) if 'openpyxl' not in sys.modules: return False @@ -349,4 +496,57 @@ def save_images_to_excel(file_name, sheet_name, image_names): except Exception as ex: import traceback print(u'save_images_to_excel exception:{}'.format(str(ex)), traceback.format_exc(),file=sys.stderr) - return False \ No newline at end of file + return False + + +def display_dual_axis(df, columns1, columns2=[], invert_yaxis1=False, invert_yaxis2=False, file_name=None, sheet_name=None, + image_name=None): + """ + 显示(保存)双Y轴的走势图 + :param df: DataFrame + :param columns1: y1轴 + :param columns2: Y2轴 + :param invert_yaxis1: Y1 轴反转 + :param invert_yaxis2: Y2 轴翻转 + :param file_name: 保存的excel 文件名称 + :param sheet_name: excel 的sheet + :param image_name: 保存的image 文件名 + :return: + """ + + import matplotlib + import matplotlib.pyplot as plt + matplotlib.rcParams['figure.figsize'] = (20.0, 10.0) + + df1 = df[columns1] + df1.index = list(range(len(df))) + fig, ax1 = plt.subplots() + if invert_yaxis1: + ax1.invert_yaxis() + ax1.plot(df1) + + if len(columns2) > 0: + df2 = df[columns2] + df2.index = list(range(len(df))) + ax2 = ax1.twinx() + if invert_yaxis2: + ax2.invert_yaxis() + ax2.plot(df2) + + # 修改x轴得label为时间 + xt = ax1.get_xticks() + xt2 = [df.index[int(i)] for i in xt[1:-2]] + xt2.insert(0, '') + xt2.append('') + ax1.set_xticklabels(xt2) + + # 是否保存图片到文件 + if image_name is not None: + fig = plt.gcf() + fig.savefig(image_name, bbox_inches='tight') + + # 插入图片到指定的excel文件sheet中并保存excel + if file_name is not None and sheet_name is not None: + save_images_to_excel(file_name, sheet_name, [image_name]) + else: + plt.show() \ No newline at end of file