修改持仓计算,逐笔盈亏

This commit is contained in:
msincenselee 2019-06-03 11:10:31 +08:00
parent 674c9f04fe
commit 4b9b03aa3b
4 changed files with 1589 additions and 354 deletions

File diff suppressed because it is too large Load Diff

View File

@ -167,6 +167,7 @@ class CtpseGateway(VtGateway):
mdAddress = str(setting['mdAddress'])
self.debug_tick = setting.get('debug_tick',False)
self.debug = setting.get('debug',False)
# 如果json文件提供了验证码
if 'authCode' in setting:
@ -494,8 +495,7 @@ class CtpMdApi(MdApi):
def onHeartBeatWarning(self, n):
"""心跳报警"""
# 因为API的心跳报警比较常被触发且与API工作关系不大因此选择忽略
if getattr(self.gateway,'debug',False):
print('onHeartBeatWarning')
pass
#----------------------------------------------------------------------
def onRspError(self, error, n, last):
@ -554,15 +554,12 @@ class CtpMdApi(MdApi):
def onRspSubMarketData(self, data, error, n, last):
"""订阅合约回报"""
# 通常不在乎订阅错误,选择忽略
if getattr(self.gateway, 'debug', False):
print('onRspSubMarketData')
#----------------------------------------------------------------------
def onRspUnSubMarketData(self, data, error, n, last):
"""退订合约回报"""
# 同上
if getattr(self.gateway, 'debug', False):
print('onRspUnSubMarketData')
pass
#----------------------------------------------------------------------
def onRtnDepthMarketData(self, data):
@ -572,8 +569,6 @@ class CtpMdApi(MdApi):
# self.writeLog(u'忽略成交量为0的无效单合约tick数据:')
# self.writeLog(data)
# return
if getattr(self.gateway, 'debug', False):
print('onRtnDepthMarketData')
if not self.connectionStatus:
self.connectionStatus = True
@ -683,29 +678,25 @@ class CtpMdApi(MdApi):
self.gateway.onCustomerTick(tick)
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def onRspSubForQuoteRsp(self, data, error, n, last):
"""订阅期权询价"""
if getattr(self.gateway, 'debug', False):
print('onRspSubForQuoteRsp')
#----------------------------------------------------------------------
def onRspUnSubForQuoteRsp(self, data, error, n, last):
"""退订期权询价"""
if getattr(self.gateway, 'debug', False):
print('onRspUnSubForQuoteRsp')
#----------------------------------------------------------------------
def onRtnForQuoteRsp(self, data):
"""期权询价推送"""
if getattr(self.gateway, 'debug', False):
print('onRtnForQuoteRsp')
pass
#----------------------------------------------------------------------
def onRspUnSubForQuoteRsp(self, data, error, n, last):
"""退订期权询价"""
pass
# ----------------------------------------------------------------------
def onRtnForQuoteRsp(self, data):
"""期权询价推送"""
pass
# ----------------------------------------------------------------------
def connect(self, userID, password, brokerID, address):
"""初始化连接"""
if getattr(self.gateway, 'debug', False):
print('connect')
self.userID = userID # 账号
self.password = password # 密码
self.brokerID = brokerID # 经纪商代码
@ -718,7 +709,7 @@ class CtpMdApi(MdApi):
if not os.path.exists(path):
os.makedirs(path)
self.createFtdcMdApi(path)
self.writeLog(u'注册行情服务器地址:{}'.format(self.address))
# 注册服务器地址
self.registerFront(self.address)
@ -732,37 +723,43 @@ class CtpMdApi(MdApi):
if not self.loginStatus:
self.login()
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def subscribe(self, subscribeReq):
"""订阅合约"""
# 这里的设计是,如果尚未登录就调用了订阅方法
# 则先保存订阅请求,登录完成后会自动订阅
if getattr(self.gateway, 'debug', False):
print('subscribe')
if self.connectionStatus and not self.loginStatus:
self.login()
# 订阅传统合约
self.subscribeMarketData(str(subscribeReq.symbol))
self.writeLog(u'订阅合约:{0}'.format(str(subscribeReq.symbol)))
self.subscribedSymbols.add(subscribeReq)
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def login(self):
"""登录"""
# 如果填入了用户名密码等,则登录
if self.userID and self.password and self.brokerID:
self.writeLog(u'登入行情服务器')
req = {}
req['UserID'] = self.userID
req['Password'] = self.password
req['BrokerID'] = self.brokerID
self.reqID += 1
self.reqUserLogin(req, self.reqID)
self.reqUserLogin(req, self.reqID)
else:
self.writeLog(u'未配置用户/密码,不登录行情服务器')
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def close(self):
"""关闭"""
self.exit()
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def writeLog(self, content):
"""发出日志"""
log = VtLogData()
@ -809,8 +806,6 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onFrontConnected(self):
"""服务器连接"""
if getattr(self.gateway, 'debug', False):
print('onFrontConnected')
self.connectionStatus = True
self.writeLog(text.TRADING_SERVER_CONNECTED)
@ -822,8 +817,6 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onFrontDisconnected(self, n):
"""服务器断开"""
if getattr(self.gateway, 'debug', False):
print('onFrontDisconnected')
self.connectionStatus = False
self.loginStatus = False
self.gateway.tdConnected = False
@ -833,15 +826,10 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onHeartBeatWarning(self, n):
""""""
if getattr(self.gateway, 'debug', False):
print('onHeartBeatWarning')
#----------------------------------------------------------------------
def onRspAuthenticate(self, data, error, n, last):
"""验证客户端回报"""
if getattr(self.gateway, 'debug', False):
print('onRspAuthenticate')
if error['ErrorID'] == 0:
self.authStatus = True
self.writeLog(text.TRADING_SERVER_AUTHENTICATED)
@ -852,8 +840,6 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspUserLogin(self, data, error, n, last):
"""登陆回报"""
if getattr(self.gateway, 'debug', False):
print('onRspUserLogin')
# 如果登录成功,推送日志信息
if error['ErrorID'] == 0:
self.tradingDay = str(data['TradingDay'])
@ -887,16 +873,12 @@ class CtpTdApi(TdApi):
def resentReqQryInstrument(self):
# 查询合约代码
if getattr(self.gateway, 'debug', False):
print('resentReqQryInstrument')
self.reqID += 1
self.reqQryInstrument({}, self.reqID)
#----------------------------------------------------------------------
def onRspUserLogout(self, data, error, n, last):
"""登出回报"""
if getattr(self.gateway, 'debug', False):
print('onRspUserLogout')
# 如果登出成功,推送日志信息
if error['ErrorID'] == 0:
self.loginStatus = False
@ -915,20 +897,16 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspUserPasswordUpdate(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspUserPasswordUpdate')
pass
#----------------------------------------------------------------------
def onRspTradingAccountPasswordUpdate(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspTradingAccountPasswordUpdate')
pass
#----------------------------------------------------------------------
def onRspOrderInsert(self, data, error, n, last):
"""发单错误(柜台)"""
if getattr(self.gateway, 'debug', False):
print('onRspOrderInsert')
# 推送委托信息
order = VtOrderData()
order.gatewayName = self.gatewayName
@ -956,20 +934,16 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspParkedOrderInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspParkedOrderInsert')
pass
#----------------------------------------------------------------------
def onRspParkedOrderAction(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspParkedOrderAction')
pass
#----------------------------------------------------------------------
def onRspOrderAction(self, data, error, n, last):
"""撤单错误(柜台)"""
if getattr(self.gateway, 'debug', False):
print('onRspOrderAction')
try:
symbol = data['InstrumentID']
except KeyError:
@ -985,89 +959,77 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQueryMaxOrderVolume(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQueryMaxOrderVolume')
pass
#----------------------------------------------------------------------
def onRspSettlementInfoConfirm(self, data, error, n, last):
"""确认结算信息回报"""
if getattr(self.gateway, 'debug', False):
print('onRspSettlementInfoConfirm')
self.writeLog(text.SETTLEMENT_INFO_CONFIRMED)
#----------------------------------------------------------------------
def onRspRemoveParkedOrder(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspSettlementInfoConfirm')
pass
#----------------------------------------------------------------------
def onRspRemoveParkedOrderAction(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspRemoveParkedOrderAction')
pass
#----------------------------------------------------------------------
def onRspExecOrderInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspExecOrderInsert')
pass
#----------------------------------------------------------------------
def onRspExecOrderAction(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspExecOrderAction')
pass
#----------------------------------------------------------------------
def onRspForQuoteInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspForQuoteInsert')
pass
#----------------------------------------------------------------------
def onRspQuoteInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQuoteInsert')
pass
#----------------------------------------------------------------------
def onRspQuoteAction(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQuoteAction')
pass
#----------------------------------------------------------------------
def onRspLockInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspLockInsert')
pass
#----------------------------------------------------------------------
def onRspCombActionInsert(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspCombActionInsert')
pass
#----------------------------------------------------------------------
def onRspQryOrder(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspCombActionInsert')
pass
#----------------------------------------------------------------------
def onRspQryTrade(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryTrade')
pass
#----------------------------------------------------------------------
def onRspQryInvestorPosition(self, data, error, n, last):
"""持仓查询回报"""
if getattr(self.gateway, 'debug', False):
if self.gateway.debug:
print('onRspQryInvestorPosition')
print(u'data:{}'.format(data))
print(u'error:{}'.format(error))
print('n:{},last:{}'.format(n,last))
if not data['InstrumentID']:
return
@ -1075,64 +1037,82 @@ class CtpTdApi(TdApi):
if not self.gateway.tdConnected:
self.gateway.tdConnected = True
# 获取持仓缓存对象
posName = '.'.join([data['InstrumentID'], data['PosiDirection']])
if posName in self.posDict:
pos = self.posDict[posName]
else:
pos = VtPositionData()
self.posDict[posName] = pos
pos.gatewayName = self.gatewayName
pos.symbol = data['InstrumentID']
pos.vtSymbol = pos.symbol
pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '')
pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName])
try:
# 获取持仓缓存对象
posName = '.'.join([data['InstrumentID'], data['PosiDirection']])
if posName in self.posDict:
pos = self.posDict[posName]
else:
pos = VtPositionData()
self.posDict[posName] = pos
exchange = self.symbolExchangeDict.get(pos.symbol, EXCHANGE_UNKNOWN)
pos.gatewayName = self.gatewayName
pos.symbol = data['InstrumentID']
pos.vtSymbol = pos.symbol
pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '')
pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction, pos.gatewayName])
# 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据
if exchange == EXCHANGE_SHFE:
if data['YdPosition'] and not data['TodayPosition']:
pos.ydPosition = data['Position']
# 否则基于总持仓和今持仓来计算昨仓数据
else:
pos.ydPosition = data['Position'] - data['TodayPosition']
exchange = self.symbolExchangeDict.get(pos.symbol, EXCHANGE_UNKNOWN)
# 计算成本
if pos.symbol not in self.symbolSizeDict:
return
size = self.symbolSizeDict[pos.symbol]
cost = pos.price * pos.position * size
# 针对上期所持仓的今昨分条返回(有昨仓、无今仓),读取昨仓数据
if exchange == EXCHANGE_SHFE:
if data['YdPosition'] and not data['TodayPosition']:
pos.ydPosition = data['Position']
# 否则基于总持仓和今持仓来计算昨仓数据
else:
pos.ydPosition = data['Position'] - data['TodayPosition']
# 汇总总仓
pos.position += data['Position']
pos.positionProfit += data['PositionProfit']
# 计算成本
if pos.symbol not in self.symbolSizeDict:
return
size = self.symbolSizeDict[pos.symbol]
cost = pos.price * pos.position * size
# 计算持仓均价
if pos.position and size:
pos.price = (cost + data['PositionCost']) / (pos.position * size)
# 汇总总仓
pos.position += data['Position']
# 读取冻结
if pos.direction is DIRECTION_LONG:
pos.frozen += data['LongFrozen']
else:
pos.frozen += data['ShortFrozen']
# 计算持仓均价
if pos.position and size:
#pos.price = (cost + data['PositionCost']) / (pos.position * size)
pos.price = (cost + data['OpenCost']) / (pos.position * size)
# 上一交易日结算价
pre_settlement_price = data['PreSettlementPrice']
# 开仓均价
open_cost_price = data['OpenCost'] / (pos.position * size)
# 逐笔盈亏 = (上一交易日结算价 - 开仓价)* 持仓数量 * 杠杆 + 当日持仓收益
if pos.direction == DIRECTION_LONG:
pre_profit = (pre_settlement_price - open_cost_price) * (pos.position * size)
else:
pre_profit = (open_cost_price - pre_settlement_price) * (pos.position * size)
pos.positionProfit = pos.positionProfit + pre_profit + data['PositionProfit']
# 读取冻结
if pos.direction is DIRECTION_LONG:
pos.frozen += data['LongFrozen']
else:
pos.frozen += data['ShortFrozen']
# 查询回报结束
if last:
# 遍历推送
if self.gateway.debug:
print(u'最后推送')
for pos in list(self.posDict.values()):
self.gateway.onPosition(pos)
# 清空缓存
self.posDict.clear()
except Exception as ex:
self.gateway.writeError('onRspQryInvestorPosition exception:{}'.format(str(ex)))
self.gateway.writeError('traceL{}'.format(traceback.format_exc()))
# 查询回报结束
if last:
# 遍历推送
for pos in list(self.posDict.values()):
self.gateway.onPosition(pos)
# 清空缓存
self.posDict.clear()
#----------------------------------------------------------------------
def onRspQryTradingAccount(self, data, error, n, last):
"""资金账户查询回报"""
if getattr(self.gateway, 'debug', False):
print('onRspQryTradingAccount')
self.gateway.mdConnected = True
account = VtAccountData()
@ -1162,14 +1142,12 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQryInvestor(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInvestor')
pass
#----------------------------------------------------------------------
def onRspQryTradingCode(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryTradingCode')
pass
#----------------------------------------------------------------------
def onRspQryInstrumentMarginRate(self, data, error, n, last):
@ -1181,32 +1159,26 @@ class CtpTdApi(TdApi):
:param last:
:return:
"""
if getattr(self.gateway, 'debug', False):
print('onRspQryInstrumentMarginRate')
pass
#----------------------------------------------------------------------
def onRspQryInstrumentCommissionRate(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInstrumentCommissionRate')
pass
#----------------------------------------------------------------------
def onRspQryExchange(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryExchange')
pass
#----------------------------------------------------------------------
def onRspQryProduct(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryProduct')
pass
#----------------------------------------------------------------------
def onRspQryInstrument(self, data, error, n, last):
"""合约查询回报"""
if getattr(self.gateway, 'debug', False):
print('onRspQryInstrument')
self.gateway.mdConnected = True
contract = VtContractData()
@ -1260,74 +1232,61 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQryDepthMarketData(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryDepthMarketData')
pass
#----------------------------------------------------------------------
def onRspQrySettlementInfo(self, data, error, n, last):
"""查询结算信息回报"""
if getattr(self.gateway, 'debug', False):
print('onRspQryDepthMarketData')
pass
#----------------------------------------------------------------------
def onRspQryTransferBank(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryTransferBank')
pass
#----------------------------------------------------------------------
def onRspQryInvestorPositionDetail(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInvestorPositionDetail')
pass
#----------------------------------------------------------------------
def onRspQryNotice(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInvestorPositionDetail')
pass
#----------------------------------------------------------------------
def onRspQrySettlementInfoConfirm(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQrySettlementInfoConfirm')
pass
#----------------------------------------------------------------------
def onRspQryInvestorPositionCombineDetail(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInvestorPositionCombineDetail')
pass
#----------------------------------------------------------------------
def onRspQryCFMMCTradingAccountKey(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryCFMMCTradingAccountKey')
pass
#----------------------------------------------------------------------
def onRspQryEWarrantOffset(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryEWarrantOffset')
pass
#----------------------------------------------------------------------
def onRspQryInvestorProductGroupMargin(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryInvestorProductGroupMargin')
pass
#----------------------------------------------------------------------
def onRspQryExchangeMarginRate(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryExchangeMarginRate')
pass
#----------------------------------------------------------------------
def onRspQryExchangeMarginRateAdjust(self, data, error, n, last):
""""""
if getattr(self.gateway, 'debug', False):
print('onRspQryExchangeMarginRateAdjust')
pass
#----------------------------------------------------------------------
def onRspQryExchangeRate(self, data, error, n, last):

View File

@ -1,6 +1,8 @@
# encoding: UTF-8
print(u'启动load vtEngine.py')
# 20190318 增加自定义套利合约Customer_Contract
print('load vtEngine.py')
import shelve
from collections import OrderedDict
@ -13,14 +15,18 @@ from pymongo.errors import ConnectionFailure,AutoReconnect
from vnpy.trader.vtEvent import Event as vn_event
from vnpy.trader.language import text
#from vnpy.trader.app.ctaStrategy.ctaEngine import CtaEngine
#from vnpy.trader.app.dataRecorder.drEngine import DrEngine
#from vnpy.trader.app.riskManager.rmEngine import RmEngine
from vnpy.trader.vtFunction import loadMongoSetting, getTempPath
from vnpy.trader.vtFunction import loadMongoSetting, getTempPath,getFullSymbol,getShortSymbol,getJsonPath
from vnpy.trader.vtGateway import *
from vnpy.trader.app import (ctaStrategy,cmaStrategy, riskManager)
from vnpy.trader.app import (ctaStrategy, riskManager)
from vnpy.trader.setup_logger import setup_logger
from vnpy.trader.vtGlobal import globalSetting
import traceback
from datetime import datetime, timedelta, time, date
from vnpy.trader.vtConstant import (DIRECTION_LONG, DIRECTION_SHORT,
OFFSET_OPEN, OFFSET_CLOSE, OFFSET_CLOSETODAY,
STATUS_ALLTRADED, STATUS_CANCELLED, STATUS_REJECTED,STATUS_UNKNOWN)
import psutil
try:
@ -67,6 +73,7 @@ class MainEngine(object):
self.ctaEngine = None # CtaEngine(self, self.eventEngine) # cta策略运行模块
self.drEngine = None # DrEngine(self, self.eventEngine) # 数据记录模块
self.rmEngine = None # RmEngine(self, self.eventEngine) # 风险管理模块
self.algoEngine = None # 算法交易模块
self.cmaEngine = None # 跨市场套利引擎
self.connected_gw_names = []
@ -76,6 +83,9 @@ class MainEngine(object):
self.createLogger()
self.spd_id = 1 # 自定义套利的委托编号
self.algo_order_dict = {} # 记录算法交易的委托编号,便于通过算法引擎撤单
# ----------------------------------------------------------------------
def addGateway(self, gatewayModule,gateway_name=EMPTY_STRING):
"""添加底层接口"""
@ -119,7 +129,9 @@ class MainEngine(object):
self.ctaEngine = self.appDict[appName]
elif appName == riskManager.appName:
self.rmEngine = self.appDict[appName]
elif appName == cmaStrategy.appName:
elif appName == 'AlgoTrading':
self.algoEngine = self.appDict[appName]
elif appName == 'CrossMarketArbitrage':
self.cmaEngine = self.appDict[appName]
# 保存应用信息
@ -198,6 +210,9 @@ class MainEngine(object):
def subscribe(self, subscribeReq, gatewayName):
"""订阅特定接口的行情"""
# 处理没有输入gatewayName的情况
if len(subscribeReq.symbol) == 0:
return
if gatewayName is None or len(gatewayName) == 0:
if len(self.connected_gw_names) == 0:
self.writeError(u'vtEngine.subscribe, no connected gateway')
@ -214,8 +229,11 @@ class MainEngine(object):
self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName))
# ----------------------------------------------------------------------
def sendOrder(self, orderReq, gatewayName):
"""对特定接口发单"""
def sendOrder(self, orderReq, gatewayName, strategyName=None):
"""
对特定接口发单
strategyName: ctaEngine中发单的策略实例名称
"""
# 如果风控检查失败则不发单
if self.rmEngine and not self.rmEngine.checkRisk(orderReq):
self.writeCritical(u'风控检查不通过,gw:{},{} {} {} p:{} v:{}'.format(gatewayName, orderReq.direction, orderReq.offset, orderReq.symbol, orderReq.price, orderReq.volume))
@ -229,21 +247,145 @@ class MainEngine(object):
orderReq.symbol, orderReq.price, orderReq.volume))
return ''
# 判断是否包含gatewayName
if gatewayName is None or len(gatewayName)==0:
if len(self.connected_gw_names) == 1:
gatewayName = self.connected_gw_names[0]
else:
self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName))
return ''
else:
if gatewayName not in self.connected_gw_names:
self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName))
return ''
# 如果合约在自定义清单中并且包含SPD则使用算法交易, 而不是CtpGateway发单
if orderReq.symbol.endswith('SPD') and orderReq.symbol in self.dataEngine.custom_contract_setting:
return self.sendAlgoOrder(orderReq, gatewayName, strategyName)
if gatewayName in self.gatewayDict:
gateway = self.gatewayDict[gatewayName]
return gateway.sendOrder(orderReq)
else:
self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName))
def sendAlgoOrder(self,orderReq, gatewayName, strategyName=None):
"""发送算法交易指令"""
self.writeLog(u'创建算法交易,gatewayName:{},strategyName:{},symbol:{},price:{},volume:{}'
.format(gatewayName, strategyName,orderReq.vtSymbol,orderReq.price,orderReq.volume))
# 创建一个Order事件
order = VtOrderData()
order.vtSymbol = orderReq.vtSymbol
order.symbol = orderReq.symbol
order.exchange = orderReq.exchange
order.gatewayName = gatewayName
order.direction = orderReq.direction
order.offset = orderReq.offset
order.price = orderReq.price
order.totalVolume = orderReq.volume
order.tradedVolume = 0
order.orderTime = datetime.now().strftime('%H:%M:%S.%f')
order.orderID = 'spd_{}'.format(self.spd_id)
self.spd_id += 1
order.vtOrderID = gatewayName+'.'+order.orderID
#如果算法引擎未启动,发出拒单事件
if self.algoEngine is None:
try:
self.writeLog(u'算法引擎未启动启动ing')
from vnpy.trader.app import algoTrading
self.addApp(algoTrading)
except Exception as ex:
self.writeError(u'算法引擎未加载,不能创建算法交易')
order.cancelTime = datetime.now().strftime('%H:%M:%S.%f')
order.status = STATUS_REJECTED
event1 = Event(type_=EVENT_ORDER)
event1.dict_['data'] = order
self.eventEngine.put(event1)
return ''
# 创建算法实例,由算法引擎启动
tradeCommand = ''
if orderReq.direction == DIRECTION_LONG and orderReq.offset == OFFSET_OPEN:
tradeCommand = 'Buy'
elif orderReq.direction == DIRECTION_SHORT and orderReq.offset == OFFSET_OPEN:
tradeCommand = 'Short'
elif orderReq.direction == DIRECTION_SHORT and orderReq.offset in [OFFSET_CLOSE, OFFSET_CLOSETODAY]:
tradeCommand = 'Sell'
elif orderReq.direction == DIRECTION_LONG and orderReq.offset in [OFFSET_CLOSE, OFFSET_CLOSETODAY]:
tradeCommand = 'Cover'
contract_setting = self.dataEngine.custom_contract_setting.get(orderReq.symbol,{})
algo = {
'templateName': u'SpreadTrading 套利',
'order_vtSymbol': orderReq.symbol,
'order_command': tradeCommand,
'order_price': orderReq.price,
'order_volume': orderReq.volume,
'timer_interval': 120,
'strategy_name': strategyName
}
algo.update(contract_setting)
# 算法引擎
algoName = self.algoEngine.addAlgo(algo)
self.writeLog(u'sendAlgoOrder(): addAlgo {}={}'.format(algoName, str(algo)))
order.status = STATUS_UNKNOWN
order.orderID = algoName
order.vtOrderID = gatewayName + u'.' + algoName
event1 = Event(type_=EVENT_ORDER)
event1.dict_['data'] = order
self.eventEngine.put(event1)
# 登记在本地的算法委托字典中
self.algo_order_dict.update({algoName: {'algo': algo, 'order': order}})
return gatewayName + u'.' + algoName
# ----------------------------------------------------------------------
def cancelOrder(self, cancelOrderReq, gatewayName):
"""对特定接口撤单"""
if gatewayName in self.gatewayDict:
if cancelOrderReq.orderID in self.algo_order_dict:
self.writeLog(u'执行算法实例撤单')
self.cancelAlgoOrder(cancelOrderReq,gatewayName)
return
if gatewayName in self.gatewayDict and gatewayName in self.connected_gw_names:
gateway = self.gatewayDict[gatewayName]
gateway.cancelOrder(cancelOrderReq)
else:
self.writeLog(text.GATEWAY_NOT_EXIST.format(gateway=gatewayName))
def cancelAlgoOrder(self, cancelOrderReq, gatewayName):
if self.algoEngine is None:
self.writeError(u'算法引擎未实例化,不能撤单:{}'.format(cancelOrderReq.orderID))
return
try:
d = self.algo_order_dict.get(cancelOrderReq.orderID)
algo = d.get('algo',None)
order = d.get('order', None)
if algo is None or order is None:
self.writeError(u'未找到算法配置和委托单,不能撤单:{}'.format(cancelOrderReq.orderID))
return
if cancelOrderReq.orderID not in self.algoEngine.algoDict:
self.writeError(u'算法实例不存在,不能撤单')
return
self.algoEngine.stopAlgo(cancelOrderReq.orderID)
order.cancelTime = datetime.now().strftime('%H:%M:%S.%f')
order.status = STATUS_CANCELLED
event1 = Event(type_=EVENT_ORDER)
event1.dict_['data'] = order
self.eventEngine.put(event1)
except Exception as ex:
self.writeError(u'算法实例撤销异常:{}\n{}'.format(str(ex),traceback.format_exc()))
# ----------------------------------------------------------------------
def qryAccount(self, gatewayName):
"""查询特定接口的账户"""
@ -330,7 +472,7 @@ class MainEngine(object):
return True
except Exception as ex:
print( u'vtEngine.disconnect Exception:{0} '.format(str(ex)))
self.writeError(u'vtEngine.disconnect Exception:{0} '.format(str(ex)))
return False
# ----------------------------------------------------------------------
@ -363,7 +505,7 @@ class MainEngine(object):
filename = os.path.abspath(os.path.join(path, 'vnpy'))
print( u'create logger:{}'.format(filename))
print(u'create logger:{}'.format(filename))
self.logger = setup_logger(filename=filename, name='vnpy', debug=True)
# ----------------------------------------------------------------------
@ -381,15 +523,22 @@ class MainEngine(object):
else:
self.createLogger()
# 发出邮件/微信
#try:
# if len(self.gatewayDetailList) > 0:
# target = self.gatewayDetailList[0]['gatewayName']
# else:
# target = WECHAT_GROUP["DEBUG_01"]
# sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_ERROR)
#except Exception as ex:
# print(u'send wechat exception:{}'.format(str(ex)),file=sys.stderr)
# 发出邮件
if globalSetting.get('activate_email', False):
try:
sendmail(subject=u'{0} Error'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except Exception as ex:
print(u'vtEngine.writeError sendmail Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# 发出微信
if globalSetting.get('activate_wx_ft', False):
try:
from huafu.util.util_wx_ft import sendWxMsg
sendWxMsg(text=content)
except Exception as ex:
print(u'vtEngine.writeError sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# ----------------------------------------------------------------------
def writeWarning(self, content):
@ -411,20 +560,22 @@ class MainEngine(object):
self.createLogger()
# 发出邮件
try:
sendmail(subject=u'{0} Warning'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except:
pass
if globalSetting.get('activate_email', False):
try:
sendmail(subject=u'{0} Warning'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except Exception as ex:
print(u'vtEngine.writeWarning sendmail Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# 发出微信
#try:
# if len(self.gatewayDetailList) > 0:
# target = self.gatewayDetailList[0]['gatewayName']
# else:
# target = WECHAT_GROUP["DEBUG_01"]
# sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_WARNING)
#except Exception as ex:
# print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr)
if globalSetting.get('activate_wx_ft', False):
try:
from huafu.util.util_wx_ft import sendWxMsg
sendWxMsg(text=content)
except Exception as ex:
print(u'vtEngine.writeWarning sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# ----------------------------------------------------------------------
def writeNotification(self, content):
@ -436,20 +587,21 @@ class MainEngine(object):
self.eventEngine.put(event)
# 发出邮件
try:
sendmail(subject=u'{0} Notification'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except:
pass
if globalSetting.get('activate_email', False):
try:
sendmail(subject=u'{0} Notification'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except Exception as ex:
print(u'vtEngine.writeNotification sendmail Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# 发出微信
# try:
# if len(self.gatewayDetailList) > 0:
# target = self.gatewayDetailList[0]['gatewayName']
# else:
# target = WECHAT_GROUP["DEBUG_01"]
# sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_INFO)
# except Exception as ex:
# print(u'send wechat exception:{}'.format(str(ex)), file=sys.stderr)
if globalSetting.get('activate_wx_ft', False):
try:
from huafu.util.util_wx_ft import sendWxMsg
sendWxMsg(text=content)
except Exception as ex:
print(u'vtEngine.writeNotification sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# ----------------------------------------------------------------------
def writeCritical(self, content):
@ -471,23 +623,22 @@ class MainEngine(object):
self.createLogger()
# 发出邮件
try:
sendmail(subject=u'{0} Critical'.format('_'.join(self.connected_gw_names)), msgcontent=content)
from vnpy.trader.util_wx_ft import sendWxMsg
sendWxMsg(text=content,desp='Critical error')
except:
pass
## 发出微信
#try:
# # if len(self.gatewayDetailList) > 0:
# target = self.gatewayDetailList[0]['gatewayName']
# else:
# target = WECHAT_GROUP["DEBUG_01"]
# sendWeChatMsg(content, target=target, level=WECHAT_LEVEL_FATAL)
#except:
# pass
if globalSetting.get('activate_email',False):
# 发出邮件
try:
sendmail(subject=u'{0} Critical'.format('_'.join(self.connected_gw_names)), msgcontent=content)
except Exception as ex:
print(u'vtEngine.writeCritical sendmail Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
# 发出微信
if globalSetting.get('activate_wx_ft',False):
try:
from huafu.util.util_wx_ft import sendWxMsg
sendWxMsg(text=content)
except Exception as ex:
print(u'vtEngine.writeCritical sendWxMsg Exception:{}'.format(str(ex)), file=sys.stderr)
print(u'{}'.format(traceback.format_exc()), file=sys.stderr)
#
# ----------------------------------------------------------------------
def dbConnect(self):
@ -515,7 +666,6 @@ class MainEngine(object):
self.writeError(text.DATABASE_CONNECTING_FAILED)
self.db_has_connected = False
# ----------------------------------------------------------------------
def dbInsert(self, dbName, collectionName, d):
"""向MongoDB中插入数据d是具体数据"""
@ -648,7 +798,7 @@ class MainEngine(object):
self.writeError(u'dbQueryBySort exception:{}'.format(str(ex)))
return []
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def dbUpdate(self, dbName, collectionName, d, flt, upsert=False):
"""向MongoDB中更新数据d是具体数据flt是过滤条件upsert代表若无是否要插入"""
try:
@ -701,7 +851,7 @@ class MainEngine(object):
except Exception as ex:
self.writeError(u'dbDelete exception:{}'.format(str(ex)))
#----------------------------------------------------------------------
# ----------------------------------------------------------------------
def dbLogging(self, event):
"""向MongoDB中插入日志"""
log = event.dict_['data']
@ -802,7 +952,13 @@ class DataEngine(object):
# 保存合约详细信息的字典
self.contractDict = {}
# 本地自定义的合约详细配置字典
self.custom_contract_setting = {}
# 合约与自定义套利合约映射表
self.contract_spd_mapping = {}
# 保存委托数据的字典
self.orderDict = {}
@ -836,10 +992,14 @@ class DataEngine(object):
def getContract(self, vtSymbol):
"""查询合约对象"""
try:
return self.contractDict[vtSymbol]
except KeyError:
if vtSymbol in self.contractDict.keys():
return self.contractDict[vtSymbol]
return None
except Exception as ex:
print(str(ex),file=sys.stderr)
return None
# ----------------------------------------------------------------------
def getAllContracts(self):
"""查询所有合约对象(返回列表)"""
@ -863,7 +1023,28 @@ class DataEngine(object):
for key, value in d.items():
self.contractDict[key] = value
f.close()
c = Custom_Contract()
self.custom_contract_setting = c.get_config()
d = c.get_contracts()
if len(d)>0:
print(u'更新本地定制合约')
self.contractDict.update(d)
# 将leg1leg2合约对应的自定义spd合约映射
for spd_name in self.custom_contract_setting.keys():
setting = self.custom_contract_setting.get(spd_name)
leg1_symbol = setting.get('leg1_symbol')
leg2_symbol = setting.get('leg2_symbol')
for symbol in [leg1_symbol,leg2_symbol]:
spd_mapping_list = self.contract_spd_mapping.get(symbol,[])
# 更新映射
if spd_name not in spd_mapping_list:
spd_mapping_list.append(spd_name)
self.contract_spd_mapping.update({symbol:spd_mapping_list})
# ----------------------------------------------------------------------
def updateOrder(self, event):
"""更新委托数据"""
@ -938,8 +1119,7 @@ class DataEngine(object):
def updatePosition(self,event):
"""更新持仓信息"""
# 在获取更新持仓信息时自动订阅这个symbol
# 目的1、
# 1、在获取更新持仓信息时自动订阅这个symbol
position = event.dict_['data']
symbol = position.symbol
@ -974,3 +1154,59 @@ class DataEngine(object):
self.mainEngine.writeLog(u'自动订阅合约{0}'.format(symbol))
class Custom_Contract(object):
"""
定制合约
# 适用于初始化系统时,补充到本地合约信息文件中 contracts.vt
# 适用于CTP网关加载自定义的套利合约做内部行情撮合
"""
# 运行本地目录下定制合约的配置文件dict
file_name = 'Custom_Contracts.json'
custom_config_file = getJsonPath(file_name, __file__)
def __init__(self):
"""构造函数"""
self.setting = {} # 所有设置
try:
# 配置文件是否存在
if not os.path.exists(self.custom_config_file):
return
# 加载配置文件,兼容中文说明
with open(self.custom_config_file,'r',encoding='UTF-8') as f:
# 解析json文件
print(u'{}文件加载定制合约'.format(self.custom_config_file))
self.setting = json.load(f)
except IOError:
print('读取{} 出错'.format(self.custom_config_file),file=sys.stderr)
def get_config(self):
"""获取配置"""
return self.setting
def get_contracts(self):
"""获取所有合约信息"""
d = {}
for k,v in self.setting.items():
contract = VtContractData()
contract.gatewayName = v.get('gateway_name',None)
if contract.gatewayName is None and globalSetting.get('gateway_name',None) is not None:
contract.gatewayName = globalSetting.get('gateway_name')
contract.symbol = k
contract.exchange = v.get('exchange',None)
contract.vtSymbol = contract.symbol
contract.name = v.get('name',contract.symbol)
contract.size = v.get('size',100)
contract.priceTick = v.get('minDiff',0.01) # 最小跳动
contract.strikePrice = v.get('strike_price',None)
contract.underlyingSymbol = v.get('underlying_symbol')
contract.longMarginRatio = v.get('margin_rate',0.1)
contract.shortMarginRatio = v.get('margin_rate',0.1)
contract.productClass = v.get('product_class',None)
d[contract.vtSymbol] = contract
return d

View File

@ -7,14 +7,58 @@
import os,sys
import decimal
import json
from datetime import datetime
from datetime import datetime,timedelta
import importlib
import re
from functools import lru_cache
MAX_NUMBER = 10000000000000
MAX_DECIMAL = 8
def import_module_by_str(import_module_name):
"""
动态导入模块/函数
:param import_module_name:
:return:
"""
import traceback
from importlib import import_module, reload
# 参数检查
if len(import_module_name) == 0:
print('import_module_by_str parameter error,return None')
return None
print('trying to import {}'.format(import_module_name))
try:
import_name = str(import_module_name).replace(':', '.')
modules = import_name.split('.')
if len(modules) == 1:
mod = import_module(modules[0])
return mod
else:
loaded_modules = '.'.join(modules[0:-1])
print('import {}'.format(loaded_modules))
mod = import_module(loaded_modules)
comp = modules[-1]
if not hasattr(mod, comp):
loaded_modules = '.'.join([loaded_modules,comp])
print('realod {}'.format(loaded_modules))
mod = reload(loaded_modules)
else:
print('from {} import {}'.format(loaded_modules,comp))
mod = getattr(mod, comp)
return mod
except Exception as ex:
print('import {} fail,{},{}'.format(import_module_name,str(ex),traceback.format_exc()))
return None
def floatToStr(float_str):
"""格式化显示浮点字符串去除后面的0"""
if '.' in float_str:
@ -30,6 +74,45 @@ def floatToStr(float_str):
else:
return float_str
# ----------------------------------------------------------------------
def roundToPriceTick(priceTick, price):
"""取整价格到合约最小价格变动"""
if not priceTick:
return price
if price > 0:
# 根据最小跳动取整
newPrice = price - price % priceTick
else:
# 兼容套利品种的负数价格
newPrice = round(price / priceTick, 0) * priceTick
# 数字货币,对浮点的长度有要求,需要砍除多余
if isinstance(priceTick,float):
price_exponent = decimal.Decimal(str(newPrice))
tick_exponent = decimal.Decimal(str(priceTick))
if abs(price_exponent.as_tuple().exponent) > abs(tick_exponent.as_tuple().exponent):
newPrice = round(newPrice, ndigits=abs(tick_exponent.as_tuple().exponent))
newPrice = float(str(newPrice))
return newPrice
def roundToVolumeTick(volumeTick,volume):
if volumeTick == 0:
return volume
# 取整
newVolume = volume - volume % volumeTick
if isinstance(volumeTick,float):
v_exponent = decimal.Decimal(str(newVolume))
vt_exponent = decimal.Decimal(str(volumeTick))
if abs(v_exponent.as_tuple().exponent) > abs(vt_exponent.as_tuple().exponent):
newVolume = round(newVolume, ndigits=abs(vt_exponent.as_tuple().exponent))
newVolume = float(str(newVolume))
return newVolume
@lru_cache()
def getShortSymbol(symbol):
"""取得合约的短号"""
# 套利合约
@ -55,15 +138,25 @@ def getShortSymbol(symbol):
return shortSymbol.group(1)
@lru_cache()
def getFullSymbol(symbol):
"""获取全路径得合约名称"""
"""
获取全路径得合约名称
"""
if symbol.endswith('SPD'):
return symbol
short_symbol = getShortSymbol(symbol)
if short_symbol == symbol:
return symbol
symbol_month = symbol.replace(short_symbol, '')
if len(symbol_month) == 3:
return '{0}1{1}'.format(short_symbol, symbol_month)
if symbol_month[0] == '0':
# 支持2020年合约
return '{0}2{1}'.format(short_symbol, symbol_month)
else:
return '{0}1{1}'.format(short_symbol, symbol_month)
else:
return symbol
@ -116,6 +209,14 @@ def safeUnicode(value):
return value
def get_tdx_market_code(code):
# 获取通达信股票的market code
code = str(code)
if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]:
return 1
return 0
#----------------------------------------------------------------------
def loadMongoSetting():
"""载入MongoDB数据库的配置"""
@ -137,6 +238,31 @@ def todayDate():
"""获取当前本机电脑时间的日期"""
return datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
def getTradingDate(dt=None):
"""
根据输入的时间返回交易日的日期
:param dt:
:return:
"""
tradingDay = ''
if dt is None:
dt = datetime.now()
if dt.hour >= 21:
if dt.isoweekday() == 5:
# 星期五=》星期一
return (dt + timedelta(days=3)).strftime('%Y-%m-%d')
else:
# 第二天
return (dt + timedelta(days=1)).strftime('%Y-%m-%d')
elif dt.hour < 8 and dt.isoweekday() == 6:
# 星期六=>星期一
return (dt + timedelta(days=2)).strftime('%Y-%m-%d')
else:
return dt.strftime('%Y-%m-%d')
# 图标路径
iconPathDict = {}
@ -163,6 +289,12 @@ def getTempPath(name):
path = os.path.join(tempPath, name)
return path
def get_data_path():
"""获取存放数据文件的路径"""
data_path = os.path.join(os.getcwd(),'data')
if not os.path.exists(data_path):
os.makedirs(data_path)
return data_path
# JSON配置文件路径
jsonPathDict = {}
@ -188,14 +320,6 @@ def getJsonPath(name, moduleFile):
return moduleJsonPath
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
except:
print(u'can not import openpyxl',file=sys.stderr)
def save_df_to_excel(file_name, sheet_name, df):
"""
保存dataframe到execl
@ -207,6 +331,14 @@ def save_df_to_excel(file_name, sheet_name, df):
if file_name is None or sheet_name is None or df is None:
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
except:
print(u'can not import openpyxl', file=sys.stderr)
if 'openpyxl' not in sys.modules:
print(u'can not import openpyxl', file=sys.stderr)
return False
@ -254,6 +386,14 @@ def save_text_to_excel(file_name, sheet_name, text):
if file_name is None or len(sheet_name)==0 or len(text) == 0 :
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
except:
print(u'can not import openpyxl', file=sys.stderr)
if 'openpyxl' not in sys.modules:
return False
@ -300,6 +440,13 @@ def save_images_to_excel(file_name, sheet_name, image_names):
"""
if file_name is None or len(sheet_name) == 0 or len(image_names) == 0:
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
except:
print(u'can not import openpyxl', file=sys.stderr)
if 'openpyxl' not in sys.modules:
return False
@ -349,4 +496,57 @@ def save_images_to_excel(file_name, sheet_name, image_names):
except Exception as ex:
import traceback
print(u'save_images_to_excel exception:{}'.format(str(ex)), traceback.format_exc(),file=sys.stderr)
return False
return False
def display_dual_axis(df, columns1, columns2=[], invert_yaxis1=False, invert_yaxis2=False, file_name=None, sheet_name=None,
image_name=None):
"""
显示(保存)双Y轴的走势图
:param df: DataFrame
:param columns1: y1轴
:param columns2: Y2轴
:param invert_yaxis1: Y1 轴反转
:param invert_yaxis2: Y2 轴翻转
:param file_name: 保存的excel 文件名称
:param sheet_name: excel 的sheet
:param image_name: 保存的image 文件名
:return:
"""
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = (20.0, 10.0)
df1 = df[columns1]
df1.index = list(range(len(df)))
fig, ax1 = plt.subplots()
if invert_yaxis1:
ax1.invert_yaxis()
ax1.plot(df1)
if len(columns2) > 0:
df2 = df[columns2]
df2.index = list(range(len(df)))
ax2 = ax1.twinx()
if invert_yaxis2:
ax2.invert_yaxis()
ax2.plot(df2)
# 修改x轴得label为时间
xt = ax1.get_xticks()
xt2 = [df.index[int(i)] for i in xt[1:-2]]
xt2.insert(0, '')
xt2.append('')
ax1.set_xticklabels(xt2)
# 是否保存图片到文件
if image_name is not None:
fig = plt.gcf()
fig.savefig(image_name, bbox_inches='tight')
# 插入图片到指定的excel文件sheet中并保存excel
if file_name is not None and sheet_name is not None:
save_images_to_excel(file_name, sheet_name, [image_name])
else:
plt.show()