diff --git a/examples/VnTrader/run.py b/examples/VnTrader/run.py index b2975c27..53872f1f 100644 --- a/examples/VnTrader/run.py +++ b/examples/VnTrader/run.py @@ -35,7 +35,8 @@ elif system == 'Windows': # 加载上层应用 from vnpy.trader.app import (riskManager, ctaStrategy, - spreadTrading, algoTrading) + spreadTrading, algoTrading, + tradeCopy) #---------------------------------------------------------------------- @@ -67,6 +68,7 @@ def main(): me.addApp(ctaStrategy) me.addApp(spreadTrading) me.addApp(algoTrading) + me.addApp(tradeCopy) # 创建主窗口 mw = MainWindow(me, ee) diff --git a/vnpy/__init__.py b/vnpy/__init__.py index 708620bf..4d027381 100644 --- a/vnpy/__init__.py +++ b/vnpy/__init__.py @@ -1,4 +1,4 @@ # encoding: UTF-8 -__version__ = '1.9.0' +__version__ = '1.9.2' __author__ = 'Xiaoyou Chen' \ No newline at end of file diff --git a/vnpy/trader/app/tradeCopy/__init__.py b/vnpy/trader/app/tradeCopy/__init__.py new file mode 100644 index 00000000..6f445281 --- /dev/null +++ b/vnpy/trader/app/tradeCopy/__init__.py @@ -0,0 +1,10 @@ +# encoding: UTF-8 + +from .tcEngine import TcEngine +from .uiTcWidget import TcManager + +appName = 'TradeCopy' +appDisplayName = u'交易复制' +appEngine = TcEngine +appWidget = TcManager +appIco = 'tc.ico' \ No newline at end of file diff --git a/vnpy/trader/app/tradeCopy/tc.ico b/vnpy/trader/app/tradeCopy/tc.ico new file mode 100644 index 00000000..e7ee47a9 Binary files /dev/null and b/vnpy/trader/app/tradeCopy/tc.ico differ diff --git a/vnpy/trader/app/tradeCopy/tcEngine.py b/vnpy/trader/app/tradeCopy/tcEngine.py new file mode 100644 index 00000000..b6493137 --- /dev/null +++ b/vnpy/trader/app/tradeCopy/tcEngine.py @@ -0,0 +1,319 @@ +# encoding: UTF-8 + +from collections import defaultdict + +from vnpy.event import Event +from vnpy.rpc import RpcClient, RpcServer +from vnpy.trader.vtEvent import EVENT_POSITION, EVENT_TRADE, EVENT_TIMER +from vnpy.trader.vtConstant import (DIRECTION_LONG, DIRECTION_SHORT, + OFFSET_OPEN, OFFSET_CLOSE, PRICETYPE_LIMITPRICE, + OFFSET_CLOSEYESTERDAY, OFFSET_CLOSETODAY) +from vnpy.trader.vtObject import VtOrderReq, VtCancelOrderReq, VtLogData, VtSubscribeReq + + +EVENT_TC_LOG = 'eTcLog' + + +######################################################################## +class TcEngine(object): + """交易复制引擎""" + MODE_PROVIDER = 1 + MODE_SUBSCRIBER = 2 + + #---------------------------------------------------------------------- + def __init__(self, mainEngine, eventEngine): + """Constructor""" + self.mainEngine = mainEngine + self.eventEngine = eventEngine + + self.mode = None # Subscriber/Provider + self.posDict = defaultdict(int) # vtPositionName:int + self.targetDict = defaultdict(int) # vtPositionName:int + self.copyRatio = 1 + self.interval = 1 + self.subscribeSet = set() + + self.count = 0 + self.server = None # RPC Server + self.client = None # RPC Client + + self.registerEvent() + + #---------------------------------------------------------------------- + def startProvider(self, repAddress, pubAddress, interval): + """""" + self.mode = self.MODE_PROVIDER + self.interval = interval + + self.server = RpcServer(repAddress, pubAddress) + self.server.usePickle() + self.server.register(self.getPos) + self.server.start() + + self.writeLog(u'启动发布者模式') + + #---------------------------------------------------------------------- + def startSubscriber(self, reqAddress, subAddress, copyRatio): + """""" + self.mode = self.MODE_SUBSCRIBER + self.copyRatio = copyRatio + + self.client = TcClient(self, reqAddress, subAddress) + self.client.usePickle() + self.client.subscribeTopic('') + self.client.start() + + self.writeLog(u'启动订阅者模式,运行时请不要执行其他交易操作') + + self.initTarget() + + #---------------------------------------------------------------------- + def stop(self): + """""" + if self.client: + self.client.stop() + self.writeLog(u'订阅者模式已停止') + + if self.server: + self.server.stop() + self.writeLog(u'发布者模式已停止') + + self.mode = None + + #---------------------------------------------------------------------- + def registerEvent(self): + """""" + self.eventEngine.register(EVENT_POSITION, self.processPositionEvent) + self.eventEngine.register(EVENT_TRADE, self.processTradeEvent) + self.eventEngine.register(EVENT_TIMER, self.processTimerEvent) + + #---------------------------------------------------------------------- + def checkAndTrade(self, vtSymbol): + """""" + if self.checkNoWorkingOrder(vtSymbol): + self.newOrder(vtSymbol) + else: + self.cancelOrder(vtSymbol) + + #---------------------------------------------------------------------- + def processTimerEvent(self, event): + """""" + if self.mode != self.MODE_PROVIDER: + return + + self.count += 1 + if self.count < self.interval: + return + self.count = 0 + + for vtPositionName in self.posDict.keys(): + self.publishPos(vtPositionName) + + #---------------------------------------------------------------------- + def processTradeEvent(self, event): + """""" + trade = event.dict_['data'] + + vtPositionName = '.'.join([trade.vtSymbol, trade.direction]) + + if trade.offset == OFFSET_OPEN: + self.posDict[vtPositionName] += trade.volume + else: + self.posDict[vtPositionName] -= trade.volume + + if self.mode == self.MODE_PROVIDER: + self.publishPos(vtPositionName) + + #---------------------------------------------------------------------- + def processPositionEvent(self, event): + """""" + position = event.dict_['data'] + self.posDict[position.vtPositionName] = position.position + + #---------------------------------------------------------------------- + def publishPos(self, vtPositionName): + """""" + l = vtPositionName.split('.') + direction = l[-1] + vtSymbol = vtPositionName.replace('.' + direction, '') + + data = { + 'vtSymbol': vtSymbol, + 'vtPositionName': vtPositionName, + 'pos': self.posDict[vtPositionName] + } + self.server.publish('', data) + + #---------------------------------------------------------------------- + def updatePos(self, data): + """""" + vtSymbol = data['vtSymbol'] + if vtSymbol not in self.subscribeSet: + contract = self.mainEngine.getContract(vtSymbol) + + req = VtSubscribeReq() + req.symbol = contract.symbol + req.exchange = contract.exchange + self.mainEngine.subscribe(req, contract.gatewayName) + + vtPositionName = data['vtPositionName'] + target = int(data['pos'] * self.copyRatio) + self.targetDict[vtPositionName] = target + + self.checkAndTrade(vtSymbol) + + #---------------------------------------------------------------------- + def newOrder(self, vtSymbol): + """""" + for vtPositionName in self.targetDict.keys(): + if vtSymbol not in vtPositionName: + continue + + pos = self.posDict[vtPositionName] + target = self.targetDict[vtPositionName] + if pos == target: + continue + + contract = self.mainEngine.getContract(vtSymbol) + tick = self.mainEngine.getTick(vtSymbol) + if not tick: + return + + req = VtOrderReq() + req.symbol = contract.symbol + req.exchange = contract.exchange + req.priceType = PRICETYPE_LIMITPRICE + req.volume = abs(target - pos) + + # Open position + if target > pos: + req.offset = OFFSET_OPEN + + if DIRECTION_LONG in vtPositionName: + req.direction = DIRECTION_LONG + if tick.upperLimit: + req.price = tick.upperLimit + else: + req.price = tick.askPrice1 + elif DIRECTION_SHROT in vtPositionName: + req.direction = DIRECTION_SHROT + if tick.lowerLimit: + req.price = tick.lowerLimit + else: + req.price = tick.bidPrice1 + + self.mainEngine.sendOrder(req, contract.gatewayName) + + # Close position + elif target < pos: + req.offset = OFFSET_CLOSE + + if DIRECTION_LONG in vtPositionName: + req.direction = DIRECTION_SHROT + if tick.upperLimit: + req.price = tick.upperLimit + else: + req.price = tick.askPrice1 + + elif DIRECTION_SHROT in vtPositionName: + req.direction = DIRECTION_LONG + if tick.lowerLimit: + req.price = tick.lowerLimit + else: + req.price = tick.bidPrice1 + + # Use auto-convert for solving today/yesterday position problem + reqList = self.mainEngine.convertOrderReq(req) + for convertedReq in reqList: + self.mainEngine.sendOrder(convertedReq, contract.gatewayName) + + # Write log + msg = u'发出%s委托 %s%s %s@%s' %(vtSymbol, req.direction, req.offset, + req.price, req.volume) + self.writeLog(msg) + + #---------------------------------------------------------------------- + def cancelOrder(self, vtSymbol): + """ + Cancel all orders of a certain vtSymbol + """ + l = self.mainEngine.getAllWorkingOrders() + for order in l: + if order.vtSymbol == vtSymbol: + req = VtCancelOrderReq() + req.orderID = order.orderID + req.frontID = order.frontID + req.sessionID = order.sessionID + req.symbol = order.symbol + req.exchange = order.exchange + self.mainEngine.cancelOrder(req, order.gatewayName) + + self.writeLog(u'撤销%s全部活动中委托' %vtSymbol) + + #---------------------------------------------------------------------- + def checkNoWorkingOrder(self, vtSymbol): + """ + Check if there is still any working orders of a certain vtSymbol + """ + l = self.mainEngine.getAllWorkingOrders() + for order in l: + if order.vtSymbol == vtSymbol: + return False + + return True + + #---------------------------------------------------------------------- + def writeLog(self, msg): + """""" + log = VtLogData() + log.logContent = msg + + event = Event(EVENT_TC_LOG) + event.dict_['data'] = log + self.eventEngine.put(event) + + #---------------------------------------------------------------------- + def getPos(self): + """ + Get currenct position data of provider + """ + return dict(self.posDict) + + #---------------------------------------------------------------------- + def initTarget(self): + """ + Init target data of subscriber based on position data from provider + """ + d = self.client.getPos() + for vtPositionName, pos in d.items(): + l = vtPositionName.split('.') + direction = l[-1] + vtSymbol = vtPositionName.replace('.' + direction, '') + + data = { + 'vtPositionName': vtPositionName, + 'vtSymbol': vtSymbol, + 'pos': pos + } + self.updatePos(data) + + self.writeLog(u'目标仓位初始化完成') + + +######################################################################## +class TcClient(RpcClient): + """""" + + #---------------------------------------------------------------------- + def __init__(self, engine, reqAddress, subAddress): + """Constructor""" + super(TcClient, self).__init__(reqAddress, subAddress) + + self.engine = engine + + #---------------------------------------------------------------------- + def callback(self, topic, data): + """""" + self.engine.updatePos(data) + + \ No newline at end of file diff --git a/vnpy/trader/app/tradeCopy/uiTcWidget.py b/vnpy/trader/app/tradeCopy/uiTcWidget.py new file mode 100644 index 00000000..5cdfc9d4 --- /dev/null +++ b/vnpy/trader/app/tradeCopy/uiTcWidget.py @@ -0,0 +1,199 @@ +# encoding: UTF-8 + +import shelve + +from vnpy.event import Event +from vnpy.trader.uiQt import QtCore, QtGui, QtWidgets +from vnpy.trader.vtFunction import getTempPath + +from .tcEngine import EVENT_TC_LOG + + +######################################################################## +class TcManager(QtWidgets.QWidget): + """""" + REQ_ADDRESS = 'tcp://localhost:2015' + SUB_ADDRESS = 'tcp://localhost:2018' + REP_ADDRESS = 'tcp://*:2015' + PUB_ADDRESS = 'tcp://*:2018' + COPY_RATIO = '1' + INTERVAL = '1' + + settingFileName = 'TradeCopy.vt' + settingFilePath = getTempPath(settingFileName) + + signal = QtCore.Signal(type(Event())) + + #---------------------------------------------------------------------- + def __init__(self, tcEngine, eventEngine, parent=None): + """Constructor""" + super(TcManager, self).__init__(parent) + + self.tcEngine = tcEngine + self.eventEngine = eventEngine + + self.initUi() + self.loadSetting() + self.registerEvent() + + self.tcEngine.writeLog(u'欢迎使用TradeCopy交易复制模块') + + #---------------------------------------------------------------------- + def initUi(self): + """""" + self.setWindowTitle(u'交易复制') + self.setMinimumWidth(700) + self.setMinimumHeight(700) + + # 创建组件 + self.lineReqAddress = QtWidgets.QLineEdit(self.REQ_ADDRESS) + self.lineSubAddress= QtWidgets.QLineEdit(self.SUB_ADDRESS) + self.lineRepAddress = QtWidgets.QLineEdit(self.REP_ADDRESS) + self.linePubAddress = QtWidgets.QLineEdit(self.PUB_ADDRESS) + + validator = QtGui.QDoubleValidator() + validator.setBottom(0) + self.lineCopyRatio = QtWidgets.QLineEdit() + self.lineCopyRatio.setValidator(validator) + self.lineCopyRatio.setText(self.COPY_RATIO) + + validator2 = QtGui.QIntValidator() + validator2.setBottom(1) + self.lineInterval = QtWidgets.QLineEdit() + self.lineInterval.setValidator(validator2) + self.lineInterval.setText(self.INTERVAL) + + self.buttonProvider = QtWidgets.QPushButton(u'启动发布者') + self.buttonProvider.clicked.connect(self.startProvider) + + self.buttonSubscriber = QtWidgets.QPushButton(u'启动订阅者') + self.buttonSubscriber.clicked.connect(self.startSubscriber) + + self.buttonStopEngine = QtWidgets.QPushButton(u'停止') + self.buttonStopEngine.clicked.connect(self.stopEngine) + self.buttonStopEngine.setEnabled(False) + + self.buttonResetAddress = QtWidgets.QPushButton(u'重置地址') + self.buttonResetAddress.clicked.connect(self.resetAddress) + + self.logMonitor = QtWidgets.QTextEdit() + self.logMonitor.setReadOnly(True) + + self.widgetList = [ + self.lineCopyRatio, + self.lineInterval, + self.linePubAddress, + self.lineSubAddress, + self.lineRepAddress, + self.lineReqAddress, + self.buttonProvider, + self.buttonSubscriber, + self.buttonResetAddress + ] + + # 布局 + QLabel = QtWidgets.QLabel + grid = QtWidgets.QGridLayout() + + grid.addWidget(QLabel(u'响应地址'), 0, 0) + grid.addWidget(self.lineRepAddress, 0, 1) + grid.addWidget(QLabel(u'请求地址'), 0, 2) + grid.addWidget(self.lineReqAddress, 0, 3) + + grid.addWidget(QLabel(u'发布地址'), 1, 0) + grid.addWidget(self.linePubAddress, 1, 1) + grid.addWidget(QLabel(u'订阅地址'), 1, 2) + grid.addWidget(self.lineSubAddress, 1, 3) + + grid.addWidget(QLabel(u'发布间隔(秒)'), 2, 0) + grid.addWidget(self.lineInterval, 2, 1) + grid.addWidget(QLabel(u'复制比例(倍)'), 2, 2) + grid.addWidget(self.lineCopyRatio, 2, 3) + + grid.addWidget(self.buttonProvider, 3, 0, 1, 2) + grid.addWidget(self.buttonSubscriber, 3, 2, 1, 2) + grid.addWidget(self.buttonStopEngine, 4, 0, 1, 2) + grid.addWidget(self.buttonResetAddress, 4, 2, 1, 2) + + grid.addWidget(self.logMonitor, 5, 0, 1, 4) + + self.setLayout(grid) + + #---------------------------------------------------------------------- + def saveSetting(self): + """""" + f = shelve.open(self.settingFilePath) + f['repAddress'] = self.lineRepAddress.text() + f['reqAddress'] = self.lineReqAddress.text() + f['pubAddress'] = self.linePubAddress.text() + f['subAddress'] = self.lineSubAddress.text() + f['copyRatio'] = self.lineCopyRatio.text() + f['interval'] = self.lineInterval.text() + f.close() + + #---------------------------------------------------------------------- + def loadSetting(self): + """""" + f = shelve.open(self.settingFilePath) + if f: + self.lineRepAddress.setText(f['repAddress']) + self.lineReqAddress.setText(f['reqAddress']) + self.linePubAddress.setText(f['pubAddress']) + self.lineSubAddress.setText(f['subAddress']) + self.lineCopyRatio.setText(f['copyRatio']) + self.lineInterval.setText(f['interval']) + f.close() + + #---------------------------------------------------------------------- + def resetAddress(self): + """""" + self.lineReqAddress.setText(self.REQ_ADDRESS) + self.lineRepAddress.setText(self.REP_ADDRESS) + self.linePubAddress.setText(self.PUB_ADDRESS) + self.lineSubAddress.setText(self.SUB_ADDRESS) + + #---------------------------------------------------------------------- + def stopEngine(self): + """""" + self.tcEngine.stop() + + for widget in self.widgetList: + widget.setEnabled(True) + self.buttonStopEngine.setEnabled(False) + + #---------------------------------------------------------------------- + def registerEvent(self): + """""" + self.signal.connect(self.processLogEvent) + self.eventEngine.register(EVENT_TC_LOG, self.signal.emit) + + #---------------------------------------------------------------------- + def processLogEvent(self, event): + """""" + log = event.dict_['data'] + txt = '%s: %s' %(log.logTime, log.logContent) + self.logMonitor.append(txt) + + #---------------------------------------------------------------------- + def startProvider(self): + """""" + repAddress = str(self.lineRepAddress.text()) + pubAddress = str(self.linePubAddress.text()) + interval = int(self.lineInterval.text()) + self.tcEngine.startProvider(repAddress, pubAddress, interval) + + for widget in self.widgetList: + widget.setEnabled(False) + self.buttonStopEngine.setEnabled(True) + + #---------------------------------------------------------------------- + def startSubscriber(self): + """""" + reqAddress = str(self.lineReqAddress.text()) + subAddress = str(self.lineSubAddress.text()) + copyRatio = float(self.lineCopyRatio.text()) + self.tcEngine.startSubscriber(reqAddress, subAddress, copyRatio) + + for widget in self.widgetList: + widget.setEnabled(False) + self.buttonStopEngine.setEnabled(True) \ No newline at end of file