[Add] BasicSpreadStrategy for demo

This commit is contained in:
vn.py 2019-09-17 22:35:40 +08:00
parent 80d89d1cb8
commit 6a5d04e61f
7 changed files with 371 additions and 48 deletions

View File

@ -75,7 +75,7 @@ def main():
# main_engine.add_gateway(DaGateway)
main_engine.add_gateway(CoinbaseGateway)
# main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaStrategyApp)
# main_engine.add_app(CtaBacktesterApp)
# main_engine.add_app(CsvLoaderApp)
# main_engine.add_app(AlgoTradingApp)

View File

@ -1,8 +1,19 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from vnpy.trader.object import (
OrderData,
TradeData
)
from .engine import SpreadEngine, APP_NAME
from .engine import (
SpreadEngine,
APP_NAME,
SpreadData,
LegData,
SpreadStrategyTemplate,
SpreadAlgoTemplate
)
class SpreadTradingApp(BaseApp):

View File

@ -48,6 +48,8 @@ class SpreadEngine(BaseEngine):
self.add_spread = self.data_engine.add_spread
self.remove_spread = self.data_engine.remove_spread
self.get_spread = self.data_engine.get_spread
self.get_all_spreads = self.data_engine.get_all_spreads
self.start_algo = self.algo_engine.start_algo
self.stop_algo = self.algo_engine.stop_algo
@ -60,6 +62,13 @@ class SpreadEngine(BaseEngine):
self.data_engine.start()
self.algo_engine.start()
self.strategy_engine.start()
def stop(self):
""""""
self.data_engine.stop()
self.algo_engine.stop()
self.strategy_engine.stop()
def write_log(self, msg: str):
""""""
@ -94,6 +103,10 @@ class SpreadDataEngine:
self.write_log("价差数据引擎启动成功")
def stop(self):
""""""
pass
def load_setting(self) -> None:
""""""
setting = load_json(self.setting_filename)
@ -251,11 +264,20 @@ class SpreadDataEngine:
spread = self.spreads.pop(name)
for leg in spread.legs:
for leg in spread.legs.values():
self.symbol_spread_map[leg.vt_symbol].remove(spread)
self.save_setting()
self.write_log("价差删除成功:{}".format(name))
self.write_log("价差移除成功:{},重启后生效".format(name))
def get_spread(self, name: str) -> SpreadData:
""""""
spread = self.spreads.get(name, None)
return spread
def get_all_spreads(self) -> List[SpreadData]:
""""""
return list(self.spreads.values())
class SpreadAlgoEngine:
@ -289,6 +311,11 @@ class SpreadAlgoEngine:
self.write_log("价差算法引擎启动成功")
def stop(self):
""""""
for algo in self.algos.values():
self.stop_algo(algo)
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
@ -533,9 +560,10 @@ class SpreadStrategyEngine:
self.vt_tradeids: Set = set()
self.load_strategy_class()
def start(self):
""""""
self.load_strategy_class()
self.load_strategy_setting()
self.register_event()
@ -551,7 +579,7 @@ class SpreadStrategyEngine:
"""
path1 = Path(__file__).parent.joinpath("strategies")
self.load_strategy_class_from_folder(
path1, "vnpy.app.cta_strategy.strategies")
path1, "vnpy.app.spread_trading.strategies")
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies")
@ -642,7 +670,8 @@ class SpreadStrategyEngine:
strategies = self.spread_strategy_map[spread.name]
for strategy in strategies:
self.call_strategy_func(strategy, strategy.on_spread_data)
if strategy.inited:
self.call_strategy_func(strategy, strategy.on_spread_data)
def process_spread_pos_event(self, event: Event):
""""""
@ -650,7 +679,8 @@ class SpreadStrategyEngine:
strategies = self.spread_strategy_map[spread.name]
for strategy in strategies:
self.call_strategy_func(strategy, strategy.on_spread_pos)
if strategy.inited:
self.call_strategy_func(strategy, strategy.on_spread_pos)
def process_spread_algo_event(self, event: Event):
""""""
@ -671,7 +701,7 @@ class SpreadStrategyEngine:
def process_trade_event(self, event: Event):
""""""
trade = event.data
strategy = self.trade_strategy_map.get(trade.vt_orderid, None)
strategy = self.order_strategy_map.get(trade.vt_orderid, None)
if strategy:
self.call_strategy_func(strategy, strategy.on_trade, trade)
@ -692,7 +722,7 @@ class SpreadStrategyEngine:
strategy.inited = False
msg = f"触发异常已停止\n{traceback.format_exc()}"
self.write_log(msg, strategy)
self.write_strategy_log(strategy, msg)
def add_strategy(
self, class_name: str, strategy_name: str, spread_name: str, setting: dict
@ -709,7 +739,12 @@ class SpreadStrategyEngine:
self.write_log(f"创建策略失败,找不到策略类{class_name}")
return
strategy = strategy_class(self, strategy_name, spread_name, setting)
spread = self.spread_engine.get_spread(spread_name)
if not spread:
self.write_log(f"创建策略失败,找不到价差{spread_name}")
return
strategy = strategy_class(self, strategy_name, spread, setting)
self.strategies[strategy_name] = strategy
# Add vt_symbol to strategy map.
@ -721,6 +756,37 @@ class SpreadStrategyEngine:
self.put_strategy_event(strategy)
def edit_strategy(self, strategy_name: str, setting: dict):
"""
Edit parameters of a strategy.
"""
strategy = self.strategies[strategy_name]
strategy.update_setting(setting)
self.update_strategy_setting(strategy_name, setting)
self.put_strategy_event(strategy)
def remove_strategy(self, strategy_name: str):
"""
Remove a strategy.
"""
strategy = self.strategies[strategy_name]
if strategy.trading:
self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
return
# Remove setting
self.remove_strategy_setting(strategy_name)
# Remove from symbol strategy map
strategies = self.spread_strategy_map[strategy.spread_name]
strategies.remove(strategy)
# Remove from strategies
self.strategies.pop(strategy_name)
return True
def init_strategy(self, strategy_name: str):
""""""
strategy = self.strategies[strategy_name]
@ -758,28 +824,48 @@ class SpreadStrategyEngine:
return
self.call_strategy_func(strategy, strategy.on_stop)
strategy.trading = False
strategy.stop_all_algos()
strategy.cancel_all_orders()
strategy.trading = False
self.put_strategy_event(strategy)
def init_all_strategies(self):
""""""
for strategy in self.strategies.values():
for strategy in self.strategies.keys():
self.init_strategy(strategy)
def start_all_strategies(self):
""""""
for strategy in self.strategies.values():
for strategy in self.strategies.keys():
self.start_strategy(strategy)
def stop_all_strategies(self):
""""""
for strategy in self.strategies.values():
for strategy in self.strategies.keys():
self.stop_strategy(strategy)
def get_strategy_class_parameters(self, class_name: str):
"""
Get default parameters of a strategy class.
"""
strategy_class = self.classes[class_name]
parameters = {}
for name in strategy_class.parameters:
parameters[name] = getattr(strategy_class, name)
return parameters
def get_strategy_parameters(self, strategy_name):
"""
Get parameters of a strategy.
"""
strategy = self.strategies[strategy_name]
return strategy.get_parameters()
def start_algo(
self,
strategy: SpreadStrategyTemplate,
@ -864,7 +950,8 @@ class SpreadStrategyEngine:
""""""
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_strategy_log(strategy, "撤单失败,找不到委托{}".format(vt_orderid))
self.write_strategy_log(
strategy, "撤单失败,找不到委托{}".format(vt_orderid))
return
req = order.create_cancel_request()
@ -876,7 +963,8 @@ class SpreadStrategyEngine:
def put_strategy_event(self, strategy: SpreadStrategyTemplate):
""""""
event = Event(EVENT_SPREAD_STRATEGY, strategy)
data = strategy.get_data()
event = Event(EVENT_SPREAD_STRATEGY, data)
self.event_engine.put(event)
def write_strategy_log(self, strategy: SpreadStrategyTemplate, msg: str):

View File

@ -0,0 +1,168 @@
from vnpy.app.spread_trading import (
SpreadStrategyTemplate,
SpreadAlgoTemplate,
SpreadData,
OrderData,
TradeData
)
class BasicSpreadStrategy(SpreadStrategyTemplate):
""""""
author = "用Python的交易员"
buy_price = 0.0
sell_price = 0.0
cover_price = 0.0
short_price = 0.0
max_pos = 0.0
payup = 10
interval = 5
spread_pos = 0.0
buy_algoid = ""
sell_algoid = ""
short_algoid = ""
cover_algoid = ""
parameters = [
"buy_price",
"sell_price",
"cover_price",
"short_price",
"max_pos",
"payup",
"interval"
]
variables = [
"spread_pos",
"buy_algoid",
"sell_algoid",
"short_algoid",
"cover_algoid",
]
def __init__(
self,
strategy_engine,
strategy_name: str,
spread: SpreadData,
setting: dict
):
""""""
super().__init__(
strategy_engine, strategy_name, spread, setting
)
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
def on_start(self):
"""
Callback when strategy is started.
"""
self.write_log("策略启动")
def on_stop(self):
"""
Callback when strategy is stopped.
"""
self.write_log("策略停止")
self.buy_algoid = ""
self.sell_algoid = ""
self.short_algoid = ""
self.cover_algoid = ""
self.put_event()
def on_spread_data(self):
"""
Callback when spread price is updated.
"""
self.spread_pos = self.get_spread_pos()
# No position
if not self.spread_pos:
# Start open algos
if not self.buy_algoid:
self.buy_algoid = self.start_long_algo(
self.buy_price, self.max_pos, self.payup, self.interval
)
if not self.short_algoid:
self.short_algoid = self.start_short_algo(
self.short_price, self.max_pos, self.payup, self.interval
)
# Stop close algos
if self.sell_algoid:
self.stop_algo(self.sell_algoid)
if self.cover_algoid:
self.stop_algo(self.cover_algoid)
# Long position
elif self.spread_pos > 0:
# Start sell close algo
if not self.sell_algoid:
self.sell_algoid = self.start_short_algo(
self.sell_price, self.spread_pos, self.payup, self.interval
)
# Stop buy open algo
if self.buy_algoid:
self.stop_algo(self.buy_algoid)
# Short position
elif self.spread_pos < 0:
# Start cover close algo
if not self.cover_algoid:
self.cover_algoid = self.start_long_algo(
self.cover_price, abs(
self.spread_pos), self.payup, self.interval
)
# Stop short open algo
if self.short_algoid:
self.stop_algo(self.short_algoid)
self.put_event()
def on_spread_pos(self):
"""
Callback when spread position is updated.
"""
self.spread_pos = self.get_spread_pos()
self.put_event()
def on_spread_algo(self, algo: SpreadAlgoTemplate):
"""
Callback when algo status is updated.
"""
if not algo.is_active():
if self.buy_algoid == algo.algoid:
self.buy_algoid = ""
elif self.sell_algoid == algo.algoid:
self.sell_algoid = ""
elif self.short_algoid == algo.algoid:
self.short_algoid = ""
else:
self.cover_algoid = ""
self.put_event()
def on_order(self, order: OrderData):
"""
Callback when order status is updated.
"""
pass
def on_trade(self, trade: TradeData):
"""
Callback when new trade data is received.
"""
pass

View File

@ -147,11 +147,11 @@ class SpreadAlgoTemplate:
def update_timer(self):
""""""
self.count += 1
if self.count < self.interval:
return
self.count = 0
if self.count > self.interval:
self.count = 0
self.on_interval()
self.on_interval()
self.put_event()
def put_event(self):
""""""
@ -358,7 +358,7 @@ class SpreadStrategyTemplate:
Callback when algo status is updated.
"""
if not algo.is_active() and algo.algoid in self.algoids:
self.algoids.pop(algo.algoid)
self.algoids.remove(algo.algoid)
self.on_spread_algo(algo)
@ -367,7 +367,7 @@ class SpreadStrategyTemplate:
Callback when order status is updated.
"""
if not order.is_active() and order.vt_orderid in self.vt_orderids:
self.vt_orderids.pop(order.vt_orderid)
self.vt_orderids.remove(order.vt_orderid)
self.on_order(order)
@ -461,7 +461,7 @@ class SpreadStrategyTemplate:
volume: float,
payup: int,
interval: int,
lock: bool
lock: bool = False
) -> str:
""""""
return self.start_algo(Direction.LONG, price, volume, payup, interval, lock)
@ -472,7 +472,7 @@ class SpreadStrategyTemplate:
volume: float,
payup: int,
interval: int,
lock: bool
lock: bool = False
) -> str:
""""""
return self.start_algo(Direction.SHORT, price, volume, payup, interval, lock)

View File

@ -34,6 +34,7 @@ class SpreadManager(QtWidgets.QWidget):
self.main_engine = main_engine
self.event_engine = event_engine
self.spread_engine = main_engine.get_engine(APP_NAME)
self.init_ui()
@ -43,8 +44,8 @@ class SpreadManager(QtWidgets.QWidget):
self.setWindowTitle("价差交易")
self.algo_dialog = SpreadAlgoWidget(self.spread_engine)
algo_tab = self.create_tab("交易", self.algo_dialog)
algo_tab.setMaximumWidth(300)
algo_group = self.create_group("交易", self.algo_dialog)
algo_group.setMaximumWidth(300)
self.data_monitor = SpreadDataMonitor(
self.main_engine,
@ -63,13 +64,13 @@ class SpreadManager(QtWidgets.QWidget):
)
grid = QtWidgets.QGridLayout()
grid.addWidget(self.create_tab("价差", self.data_monitor), 0, 0)
grid.addWidget(self.create_tab("日志", self.log_monitor), 1, 0)
grid.addWidget(self.create_tab("算法", self.algo_monitor), 0, 1)
grid.addWidget(self.create_tab("策略", self.strategy_monitor), 1, 1)
grid.addWidget(self.create_group("价差", self.data_monitor), 0, 0)
grid.addWidget(self.create_group("日志", self.log_monitor), 1, 0)
grid.addWidget(self.create_group("算法", self.algo_monitor), 0, 1)
grid.addWidget(self.create_group("策略", self.strategy_monitor), 1, 1)
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(algo_tab)
hbox.addWidget(algo_group)
hbox.addLayout(grid)
self.setLayout(hbox)
@ -77,14 +78,20 @@ class SpreadManager(QtWidgets.QWidget):
def show(self):
""""""
self.spread_engine.start()
self.algo_dialog.update_class_combo()
self.showMaximized()
def create_tab(self, title: str, widget: QtWidgets.QWidget):
def create_group(self, title: str, widget: QtWidgets.QWidget):
""""""
tab = QtWidgets.QTabWidget()
tab.addTab(widget, title)
return tab
group = QtWidgets.QGroupBox()
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(widget)
group.setLayout(vbox)
group.setTitle(title)
return group
class SpreadDataMonitor(BaseMonitor):
@ -145,7 +152,7 @@ class SpreadLogMonitor(QtWidgets.QTextEdit):
def process_log_event(self, event: Event):
""""""
log = event.data
msg = f"{log.time.strftime('%H:%M:%S')}{log.msg}"
msg = f"{log.time.strftime('%H:%M:%S')}\t{log.msg}"
self.append(msg)
@ -205,7 +212,6 @@ class SpreadAlgoWidget(QtWidgets.QFrame):
self.strategy_engine: SpreadStrategyEngine = spread_engine.strategy_engine
self.init_ui()
self.update_class_combo()
def init_ui(self):
""""""
@ -314,7 +320,8 @@ class SpreadAlgoWidget(QtWidgets.QFrame):
def remove_spread(self):
""""""
pass
dialog = SpreadRemoveDialog(self.spread_engine)
dialog.exec_()
def update_class_combo(self):
""""""
@ -470,7 +477,44 @@ class SpreadDataDialog(QtWidgets.QDialog):
self.accept()
class SpreadStrategyMonitor(QtWidgets.QScrollArea):
class SpreadRemoveDialog(QtWidgets.QDialog):
""""""
def __init__(self, spread_engine: SpreadEngine):
""""""
super().__init__()
self.spread_engine: SpreadEngine = spread_engine
self.init_ui()
def init_ui(self):
""""""
self.setWindowTitle("移除价差")
self.setMinimumWidth(300)
self.name_combo = QtWidgets.QComboBox()
spreads = self.spread_engine.get_all_spreads()
for spread in spreads:
self.name_combo.addItem(spread.name)
button_remove = QtWidgets.QPushButton("移除")
button_remove.clicked.connect(self.remove_spread)
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(self.name_combo)
hbox.addWidget(button_remove)
self.setLayout(hbox)
def remove_spread(self):
""""""
spread_name = self.name_combo.currentText()
self.spread_engine.remove_spread(spread_name)
self.accept()
class SpreadStrategyMonitor(QtWidgets.QWidget):
""""""
signal_strategy = QtCore.pyqtSignal(Event)
@ -495,8 +539,13 @@ class SpreadStrategyMonitor(QtWidgets.QScrollArea):
scroll_widget = QtWidgets.QWidget()
scroll_widget.setLayout(self.scroll_layout)
self.setWidgetResizable(True)
self.setWidget(scroll_widget)
scroll_area = QtWidgets.QScrollArea()
scroll_area.setWidgetResizable(True)
scroll_area.setWidget(scroll_widget)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(scroll_area)
self.setLayout(vbox)
def register_event(self):
""""""
@ -517,10 +566,15 @@ class SpreadStrategyMonitor(QtWidgets.QScrollArea):
manager = self.managers[strategy_name]
manager.update_data(data)
else:
manager = SpreadStrategyWidget(self.strategy_engine, data)
manager = SpreadStrategyWidget(self, self.strategy_engine, data)
self.scroll_layout.insertWidget(0, manager)
self.managers[strategy_name] = manager
def remove_strategy(self, strategy_name):
""""""
manager = self.managers.pop(strategy_name)
manager.deleteLater()
class SpreadStrategyWidget(QtWidgets.QFrame):
"""
@ -529,12 +583,14 @@ class SpreadStrategyWidget(QtWidgets.QFrame):
def __init__(
self,
strategy_monitor: SpreadStrategyMonitor,
strategy_engine: SpreadStrategyEngine,
data: dict
):
""""""
super().__init__()
self.strategy_monitor = strategy_monitor
self.strategy_engine = strategy_engine
self.strategy_name = data["strategy_name"]
@ -629,7 +685,7 @@ class SpreadStrategyWidget(QtWidgets.QFrame):
# Only remove strategy gui manager if it has been removed from engine
if result:
self.spread_manager.remove_strategy(self.strategy_name)
self.strategy_monitor.remove_strategy(self.strategy_name)
class StrategyDataMonitor(QtWidgets.QTableWidget):
@ -698,11 +754,11 @@ class SettingEditor(QtWidgets.QDialog):
""""""
form = QtWidgets.QFormLayout()
# Add vt_symbol and name edit if add new strategy
# Add spread_name and name edit if add new strategy
if self.class_name:
self.setWindowTitle(f"添加策略:{self.class_name}")
button_text = "添加"
parameters = {"strategy_name": "", "vt_symbol": ""}
parameters = {"strategy_name": "", "spread_name": ""}
parameters.update(self.parameters)
else:
self.setWindowTitle(f"参数编辑:{self.strategy_name}")