[Add]新增交易复制模块TradeCopy

This commit is contained in:
vn.py 2018-12-02 13:24:04 +08:00
parent 31784f32b6
commit 644d14881c
6 changed files with 532 additions and 2 deletions

View File

@ -35,7 +35,8 @@ elif system == 'Windows':
# 加载上层应用
from vnpy.trader.app import (riskManager, ctaStrategy,
spreadTrading, algoTrading)
spreadTrading, algoTrading,
tradeCopy)
#----------------------------------------------------------------------
@ -67,6 +68,7 @@ def main():
me.addApp(ctaStrategy)
me.addApp(spreadTrading)
me.addApp(algoTrading)
me.addApp(tradeCopy)
# 创建主窗口
mw = MainWindow(me, ee)

View File

@ -1,4 +1,4 @@
# encoding: UTF-8
__version__ = '1.9.0'
__version__ = '1.9.2'
__author__ = 'Xiaoyou Chen'

View File

@ -0,0 +1,10 @@
# encoding: UTF-8
from .tcEngine import TcEngine
from .uiTcWidget import TcManager
appName = 'TradeCopy'
appDisplayName = u'交易复制'
appEngine = TcEngine
appWidget = TcManager
appIco = 'tc.ico'

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -0,0 +1,319 @@
# encoding: UTF-8
from collections import defaultdict
from vnpy.event import Event
from vnpy.rpc import RpcClient, RpcServer
from vnpy.trader.vtEvent import EVENT_POSITION, EVENT_TRADE, EVENT_TIMER
from vnpy.trader.vtConstant import (DIRECTION_LONG, DIRECTION_SHORT,
OFFSET_OPEN, OFFSET_CLOSE, PRICETYPE_LIMITPRICE,
OFFSET_CLOSEYESTERDAY, OFFSET_CLOSETODAY)
from vnpy.trader.vtObject import VtOrderReq, VtCancelOrderReq, VtLogData, VtSubscribeReq
EVENT_TC_LOG = 'eTcLog'
########################################################################
class TcEngine(object):
"""交易复制引擎"""
MODE_PROVIDER = 1
MODE_SUBSCRIBER = 2
#----------------------------------------------------------------------
def __init__(self, mainEngine, eventEngine):
"""Constructor"""
self.mainEngine = mainEngine
self.eventEngine = eventEngine
self.mode = None # Subscriber/Provider
self.posDict = defaultdict(int) # vtPositionName:int
self.targetDict = defaultdict(int) # vtPositionName:int
self.copyRatio = 1
self.interval = 1
self.subscribeSet = set()
self.count = 0
self.server = None # RPC Server
self.client = None # RPC Client
self.registerEvent()
#----------------------------------------------------------------------
def startProvider(self, repAddress, pubAddress, interval):
""""""
self.mode = self.MODE_PROVIDER
self.interval = interval
self.server = RpcServer(repAddress, pubAddress)
self.server.usePickle()
self.server.register(self.getPos)
self.server.start()
self.writeLog(u'启动发布者模式')
#----------------------------------------------------------------------
def startSubscriber(self, reqAddress, subAddress, copyRatio):
""""""
self.mode = self.MODE_SUBSCRIBER
self.copyRatio = copyRatio
self.client = TcClient(self, reqAddress, subAddress)
self.client.usePickle()
self.client.subscribeTopic('')
self.client.start()
self.writeLog(u'启动订阅者模式,运行时请不要执行其他交易操作')
self.initTarget()
#----------------------------------------------------------------------
def stop(self):
""""""
if self.client:
self.client.stop()
self.writeLog(u'订阅者模式已停止')
if self.server:
self.server.stop()
self.writeLog(u'发布者模式已停止')
self.mode = None
#----------------------------------------------------------------------
def registerEvent(self):
""""""
self.eventEngine.register(EVENT_POSITION, self.processPositionEvent)
self.eventEngine.register(EVENT_TRADE, self.processTradeEvent)
self.eventEngine.register(EVENT_TIMER, self.processTimerEvent)
#----------------------------------------------------------------------
def checkAndTrade(self, vtSymbol):
""""""
if self.checkNoWorkingOrder(vtSymbol):
self.newOrder(vtSymbol)
else:
self.cancelOrder(vtSymbol)
#----------------------------------------------------------------------
def processTimerEvent(self, event):
""""""
if self.mode != self.MODE_PROVIDER:
return
self.count += 1
if self.count < self.interval:
return
self.count = 0
for vtPositionName in self.posDict.keys():
self.publishPos(vtPositionName)
#----------------------------------------------------------------------
def processTradeEvent(self, event):
""""""
trade = event.dict_['data']
vtPositionName = '.'.join([trade.vtSymbol, trade.direction])
if trade.offset == OFFSET_OPEN:
self.posDict[vtPositionName] += trade.volume
else:
self.posDict[vtPositionName] -= trade.volume
if self.mode == self.MODE_PROVIDER:
self.publishPos(vtPositionName)
#----------------------------------------------------------------------
def processPositionEvent(self, event):
""""""
position = event.dict_['data']
self.posDict[position.vtPositionName] = position.position
#----------------------------------------------------------------------
def publishPos(self, vtPositionName):
""""""
l = vtPositionName.split('.')
direction = l[-1]
vtSymbol = vtPositionName.replace('.' + direction, '')
data = {
'vtSymbol': vtSymbol,
'vtPositionName': vtPositionName,
'pos': self.posDict[vtPositionName]
}
self.server.publish('', data)
#----------------------------------------------------------------------
def updatePos(self, data):
""""""
vtSymbol = data['vtSymbol']
if vtSymbol not in self.subscribeSet:
contract = self.mainEngine.getContract(vtSymbol)
req = VtSubscribeReq()
req.symbol = contract.symbol
req.exchange = contract.exchange
self.mainEngine.subscribe(req, contract.gatewayName)
vtPositionName = data['vtPositionName']
target = int(data['pos'] * self.copyRatio)
self.targetDict[vtPositionName] = target
self.checkAndTrade(vtSymbol)
#----------------------------------------------------------------------
def newOrder(self, vtSymbol):
""""""
for vtPositionName in self.targetDict.keys():
if vtSymbol not in vtPositionName:
continue
pos = self.posDict[vtPositionName]
target = self.targetDict[vtPositionName]
if pos == target:
continue
contract = self.mainEngine.getContract(vtSymbol)
tick = self.mainEngine.getTick(vtSymbol)
if not tick:
return
req = VtOrderReq()
req.symbol = contract.symbol
req.exchange = contract.exchange
req.priceType = PRICETYPE_LIMITPRICE
req.volume = abs(target - pos)
# Open position
if target > pos:
req.offset = OFFSET_OPEN
if DIRECTION_LONG in vtPositionName:
req.direction = DIRECTION_LONG
if tick.upperLimit:
req.price = tick.upperLimit
else:
req.price = tick.askPrice1
elif DIRECTION_SHROT in vtPositionName:
req.direction = DIRECTION_SHROT
if tick.lowerLimit:
req.price = tick.lowerLimit
else:
req.price = tick.bidPrice1
self.mainEngine.sendOrder(req, contract.gatewayName)
# Close position
elif target < pos:
req.offset = OFFSET_CLOSE
if DIRECTION_LONG in vtPositionName:
req.direction = DIRECTION_SHROT
if tick.upperLimit:
req.price = tick.upperLimit
else:
req.price = tick.askPrice1
elif DIRECTION_SHROT in vtPositionName:
req.direction = DIRECTION_LONG
if tick.lowerLimit:
req.price = tick.lowerLimit
else:
req.price = tick.bidPrice1
# Use auto-convert for solving today/yesterday position problem
reqList = self.mainEngine.convertOrderReq(req)
for convertedReq in reqList:
self.mainEngine.sendOrder(convertedReq, contract.gatewayName)
# Write log
msg = u'发出%s委托 %s%s %s@%s' %(vtSymbol, req.direction, req.offset,
req.price, req.volume)
self.writeLog(msg)
#----------------------------------------------------------------------
def cancelOrder(self, vtSymbol):
"""
Cancel all orders of a certain vtSymbol
"""
l = self.mainEngine.getAllWorkingOrders()
for order in l:
if order.vtSymbol == vtSymbol:
req = VtCancelOrderReq()
req.orderID = order.orderID
req.frontID = order.frontID
req.sessionID = order.sessionID
req.symbol = order.symbol
req.exchange = order.exchange
self.mainEngine.cancelOrder(req, order.gatewayName)
self.writeLog(u'撤销%s全部活动中委托' %vtSymbol)
#----------------------------------------------------------------------
def checkNoWorkingOrder(self, vtSymbol):
"""
Check if there is still any working orders of a certain vtSymbol
"""
l = self.mainEngine.getAllWorkingOrders()
for order in l:
if order.vtSymbol == vtSymbol:
return False
return True
#----------------------------------------------------------------------
def writeLog(self, msg):
""""""
log = VtLogData()
log.logContent = msg
event = Event(EVENT_TC_LOG)
event.dict_['data'] = log
self.eventEngine.put(event)
#----------------------------------------------------------------------
def getPos(self):
"""
Get currenct position data of provider
"""
return dict(self.posDict)
#----------------------------------------------------------------------
def initTarget(self):
"""
Init target data of subscriber based on position data from provider
"""
d = self.client.getPos()
for vtPositionName, pos in d.items():
l = vtPositionName.split('.')
direction = l[-1]
vtSymbol = vtPositionName.replace('.' + direction, '')
data = {
'vtPositionName': vtPositionName,
'vtSymbol': vtSymbol,
'pos': pos
}
self.updatePos(data)
self.writeLog(u'目标仓位初始化完成')
########################################################################
class TcClient(RpcClient):
""""""
#----------------------------------------------------------------------
def __init__(self, engine, reqAddress, subAddress):
"""Constructor"""
super(TcClient, self).__init__(reqAddress, subAddress)
self.engine = engine
#----------------------------------------------------------------------
def callback(self, topic, data):
""""""
self.engine.updatePos(data)

View File

@ -0,0 +1,199 @@
# encoding: UTF-8
import shelve
from vnpy.event import Event
from vnpy.trader.uiQt import QtCore, QtGui, QtWidgets
from vnpy.trader.vtFunction import getTempPath
from .tcEngine import EVENT_TC_LOG
########################################################################
class TcManager(QtWidgets.QWidget):
""""""
REQ_ADDRESS = 'tcp://localhost:2015'
SUB_ADDRESS = 'tcp://localhost:2018'
REP_ADDRESS = 'tcp://*:2015'
PUB_ADDRESS = 'tcp://*:2018'
COPY_RATIO = '1'
INTERVAL = '1'
settingFileName = 'TradeCopy.vt'
settingFilePath = getTempPath(settingFileName)
signal = QtCore.Signal(type(Event()))
#----------------------------------------------------------------------
def __init__(self, tcEngine, eventEngine, parent=None):
"""Constructor"""
super(TcManager, self).__init__(parent)
self.tcEngine = tcEngine
self.eventEngine = eventEngine
self.initUi()
self.loadSetting()
self.registerEvent()
self.tcEngine.writeLog(u'欢迎使用TradeCopy交易复制模块')
#----------------------------------------------------------------------
def initUi(self):
""""""
self.setWindowTitle(u'交易复制')
self.setMinimumWidth(700)
self.setMinimumHeight(700)
# 创建组件
self.lineReqAddress = QtWidgets.QLineEdit(self.REQ_ADDRESS)
self.lineSubAddress= QtWidgets.QLineEdit(self.SUB_ADDRESS)
self.lineRepAddress = QtWidgets.QLineEdit(self.REP_ADDRESS)
self.linePubAddress = QtWidgets.QLineEdit(self.PUB_ADDRESS)
validator = QtGui.QDoubleValidator()
validator.setBottom(0)
self.lineCopyRatio = QtWidgets.QLineEdit()
self.lineCopyRatio.setValidator(validator)
self.lineCopyRatio.setText(self.COPY_RATIO)
validator2 = QtGui.QIntValidator()
validator2.setBottom(1)
self.lineInterval = QtWidgets.QLineEdit()
self.lineInterval.setValidator(validator2)
self.lineInterval.setText(self.INTERVAL)
self.buttonProvider = QtWidgets.QPushButton(u'启动发布者')
self.buttonProvider.clicked.connect(self.startProvider)
self.buttonSubscriber = QtWidgets.QPushButton(u'启动订阅者')
self.buttonSubscriber.clicked.connect(self.startSubscriber)
self.buttonStopEngine = QtWidgets.QPushButton(u'停止')
self.buttonStopEngine.clicked.connect(self.stopEngine)
self.buttonStopEngine.setEnabled(False)
self.buttonResetAddress = QtWidgets.QPushButton(u'重置地址')
self.buttonResetAddress.clicked.connect(self.resetAddress)
self.logMonitor = QtWidgets.QTextEdit()
self.logMonitor.setReadOnly(True)
self.widgetList = [
self.lineCopyRatio,
self.lineInterval,
self.linePubAddress,
self.lineSubAddress,
self.lineRepAddress,
self.lineReqAddress,
self.buttonProvider,
self.buttonSubscriber,
self.buttonResetAddress
]
# 布局
QLabel = QtWidgets.QLabel
grid = QtWidgets.QGridLayout()
grid.addWidget(QLabel(u'响应地址'), 0, 0)
grid.addWidget(self.lineRepAddress, 0, 1)
grid.addWidget(QLabel(u'请求地址'), 0, 2)
grid.addWidget(self.lineReqAddress, 0, 3)
grid.addWidget(QLabel(u'发布地址'), 1, 0)
grid.addWidget(self.linePubAddress, 1, 1)
grid.addWidget(QLabel(u'订阅地址'), 1, 2)
grid.addWidget(self.lineSubAddress, 1, 3)
grid.addWidget(QLabel(u'发布间隔(秒)'), 2, 0)
grid.addWidget(self.lineInterval, 2, 1)
grid.addWidget(QLabel(u'复制比例(倍)'), 2, 2)
grid.addWidget(self.lineCopyRatio, 2, 3)
grid.addWidget(self.buttonProvider, 3, 0, 1, 2)
grid.addWidget(self.buttonSubscriber, 3, 2, 1, 2)
grid.addWidget(self.buttonStopEngine, 4, 0, 1, 2)
grid.addWidget(self.buttonResetAddress, 4, 2, 1, 2)
grid.addWidget(self.logMonitor, 5, 0, 1, 4)
self.setLayout(grid)
#----------------------------------------------------------------------
def saveSetting(self):
""""""
f = shelve.open(self.settingFilePath)
f['repAddress'] = self.lineRepAddress.text()
f['reqAddress'] = self.lineReqAddress.text()
f['pubAddress'] = self.linePubAddress.text()
f['subAddress'] = self.lineSubAddress.text()
f['copyRatio'] = self.lineCopyRatio.text()
f['interval'] = self.lineInterval.text()
f.close()
#----------------------------------------------------------------------
def loadSetting(self):
""""""
f = shelve.open(self.settingFilePath)
if f:
self.lineRepAddress.setText(f['repAddress'])
self.lineReqAddress.setText(f['reqAddress'])
self.linePubAddress.setText(f['pubAddress'])
self.lineSubAddress.setText(f['subAddress'])
self.lineCopyRatio.setText(f['copyRatio'])
self.lineInterval.setText(f['interval'])
f.close()
#----------------------------------------------------------------------
def resetAddress(self):
""""""
self.lineReqAddress.setText(self.REQ_ADDRESS)
self.lineRepAddress.setText(self.REP_ADDRESS)
self.linePubAddress.setText(self.PUB_ADDRESS)
self.lineSubAddress.setText(self.SUB_ADDRESS)
#----------------------------------------------------------------------
def stopEngine(self):
""""""
self.tcEngine.stop()
for widget in self.widgetList:
widget.setEnabled(True)
self.buttonStopEngine.setEnabled(False)
#----------------------------------------------------------------------
def registerEvent(self):
""""""
self.signal.connect(self.processLogEvent)
self.eventEngine.register(EVENT_TC_LOG, self.signal.emit)
#----------------------------------------------------------------------
def processLogEvent(self, event):
""""""
log = event.dict_['data']
txt = '%s: %s' %(log.logTime, log.logContent)
self.logMonitor.append(txt)
#----------------------------------------------------------------------
def startProvider(self):
""""""
repAddress = str(self.lineRepAddress.text())
pubAddress = str(self.linePubAddress.text())
interval = int(self.lineInterval.text())
self.tcEngine.startProvider(repAddress, pubAddress, interval)
for widget in self.widgetList:
widget.setEnabled(False)
self.buttonStopEngine.setEnabled(True)
#----------------------------------------------------------------------
def startSubscriber(self):
""""""
reqAddress = str(self.lineReqAddress.text())
subAddress = str(self.lineSubAddress.text())
copyRatio = float(self.lineCopyRatio.text())
self.tcEngine.startSubscriber(reqAddress, subAddress, copyRatio)
for widget in self.widgetList:
widget.setEnabled(False)
self.buttonStopEngine.setEnabled(True)