diff --git a/examples/vn_trader/run.py b/examples/vn_trader/run.py index 5dcbe65d..187b5abf 100644 --- a/examples/vn_trader/run.py +++ b/examples/vn_trader/run.py @@ -5,9 +5,9 @@ from vnpy.trader.engine import MainEngine from vnpy.trader.ui import MainWindow, create_qapp # from vnpy.gateway.binance import BinanceGateway -# from vnpy.gateway.bitmex import BitmexGateway +from vnpy.gateway.bitmex import BitmexGateway # from vnpy.gateway.futu import FutuGateway -from vnpy.gateway.ib import IbGateway +# from vnpy.gateway.ib import IbGateway # from vnpy.gateway.ctp import CtpGateway # from vnpy.gateway.ctptest import CtptestGateway # from vnpy.gateway.femas import FemasGateway @@ -30,6 +30,7 @@ from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.cta_backtester import CtaBacktesterApp # from vnpy.app.data_recorder import DataRecorderApp # from vnpy.app.risk_manager import RiskManagerApp +from vnpy.app.script_trader import ScriptTraderApp def main(): @@ -44,9 +45,9 @@ def main(): # main_engine.add_gateway(CtpGateway) # main_engine.add_gateway(CtptestGateway) # main_engine.add_gateway(FemasGateway) - main_engine.add_gateway(IbGateway) + # main_engine.add_gateway(IbGateway) # main_engine.add_gateway(FutuGateway) - # main_engine.add_gateway(BitmexGateway) + main_engine.add_gateway(BitmexGateway) # main_engine.add_gateway(TigerGateway) # main_engine.add_gateway(OesGateway) # main_engine.add_gateway(OkexGateway) @@ -66,6 +67,7 @@ def main(): # main_engine.add_app(AlgoTradingApp) # main_engine.add_app(DataRecorderApp) # main_engine.add_app(RiskManagerApp) + main_engine.add_app(ScriptTraderApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/app/script_trader/__init__.py b/vnpy/app/script_trader/__init__.py new file mode 100644 index 00000000..c54b69f8 --- /dev/null +++ b/vnpy/app/script_trader/__init__.py @@ -0,0 +1,14 @@ +from pathlib import Path +from vnpy.trader.app import BaseApp +from .engine import ScriptEngine, APP_NAME + + +class ScriptTraderApp(BaseApp): + """""" + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "脚本策略" + engine_class = ScriptEngine + widget_name = "ScriptManager" + icon_name = "script.ico" diff --git a/vnpy/app/script_trader/engine.py b/vnpy/app/script_trader/engine.py new file mode 100644 index 00000000..45642c24 --- /dev/null +++ b/vnpy/app/script_trader/engine.py @@ -0,0 +1,262 @@ +"""""" + +import sys +import importlib +import traceback +from typing import Sequence +from pathlib import Path +from datetime import datetime +from threading import Thread + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.constant import Direction, Offset, OrderType, Interval +from vnpy.trader.object import ( + OrderRequest, + HistoryRequest, + SubscribeRequest, + TickData, + OrderData, + TradeData, + PositionData, + AccountData, + ContractData, + LogData, + BarData +) +from vnpy.trader.rqdata import rqdata_client + + +APP_NAME = "ScriptTrader" + +EVENT_SCRIPT_LOG = "eScriptLog" + + +class ScriptEngine(BaseEngine): + """""" + setting_filename = "script_trader_setting.json" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.main_engine = main_engine + self.event_engine = event_engine + + self.get_tick = main_engine.get_tick + self.get_order = main_engine.get_order + self.get_trade = main_engine.get_trade + self.get_position = main_engine.get_position + self.get_account = main_engine.get_account + self.get_contract = main_engine.get_contract + self.get_all_ticks = main_engine.get_all_ticks + self.get_all_orders = main_engine.get_all_orders + self.get_all_trades = main_engine.get_all_trades + self.get_all_positions = main_engine.get_all_positions + self.get_all_accounts = main_engine.get_all_accounts + self.get_all_contracts = main_engine.get_all_contracts + self.get_all_active_orders = main_engine.get_all_active_orders + + self.strategy_active = False + self.strategy_thread = None + + def init(self): + """ + Start script engine. + """ + result = rqdata_client.init() + if result: + self.write_log("RQData数据接口初始化成功") + + def start_strategy(self, script_path: str): + """""" + if self.strategy_active: + return + self.strategy_active = True + + self.strategy_thread = Thread( + target=self.run_strategy, args=(script_path,)) + self.strategy_thread.start() + + self.write_log("策略交易脚本启动") + + def run_strategy(self, script_path: str): + """""" + path = Path(script_path) + sys.path.append(str(path.parent)) + + script_name = path.parts[-1] + module_name = script_name.replace(".py", "") + + try: + module = importlib.import_module(module_name) + importlib.reload(module) + module.run(self) + except: # noqa + msg = f"触发异常已停止\n{traceback.format_exc()}" + self.write_log(msg) + + def stop_strategy(self): + """""" + if not self.strategy_active: + return + self.strategy_active = False + + if self.strategy_thread: + self.strategy_thread.join() + self.strategy_thread = None + + self.write_log("策略交易脚本停止") + + def send_order( + self, + vt_symbol: str, + price: float, + volume: float, + direction: Direction, + offset: Offset, + order_type: OrderType + ) -> str: + """""" + contract = self.get_contract(vt_symbol) + if not contract: + return "" + + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + type=order_type, + volume=volume, + price=price, + offset=offset + ) + + vt_orderid = self.main_engine.send_order(req, contract.gateway_name) + return vt_orderid + + def subscribe(self, vt_symbols): + """""" + for vt_symbol in vt_symbols: + contract = self.main_engine.get_contract(vt_symbol) + if contract: + req = SubscribeRequest( + symbol=contract.symbol, + exchange=contract.exchange + ) + self.main_engine.subscribe(req, contract.gateway_name) + + def buy(self, vt_symbol: str, price: str, volume: str, order_type: OrderType = OrderType.LIMIT) -> str: + """""" + return self.send_order(vt_symbol, price, volume, Direction.LONG, Offset.OPEN, order_type) + + def sell(self, vt_symbol: str, price: str, volume: str, order_type: OrderType = OrderType.LIMIT) -> str: + """""" + return self.send_order(vt_symbol, price, volume, Direction.SHORT, Offset.CLOSE, order_type) + + def short(self, vt_symbol: str, price: str, volume: str, order_type: OrderType = OrderType.LIMIT) -> str: + """""" + return self.send_order(vt_symbol, price, volume, Direction.SHORT, Offset.OPEN, order_type) + + def cover(self, vt_symbol: str, price: str, volume: str, order_type: OrderType = OrderType.LIMIT) -> str: + """""" + return self.send_order(vt_symbol, price, volume, Direction.LONG, Offset.CLOSE, order_type) + + def cancel_order(self, vt_orderid: str) -> None: + """""" + order = self.get_order(vt_orderid) + if not order: + return + + req = order.create_cancel_request() + self.main_engine.cancel_order(req, order.gateway_name) + + def get_tick(self, vt_symbol: str) -> TickData: + """""" + return self.main_engine.get_tick(vt_symbol) + + def get_ticks(self, vt_symbols: Sequence[str]) -> Sequence[TickData]: + """""" + ticks = [] + for vt_symbol in vt_symbols: + tick = self.main_engine.get_tick(vt_symbol) + ticks.append(tick) + return ticks + + def get_order(self, vt_orderid: str) -> OrderData: + """""" + return self.main_engine.get_order(vt_orderid) + + def get_orders(self, vt_orderids: Sequence[str]) -> Sequence[OrderData]: + """""" + orders = [] + for vt_orderid in vt_orderids: + order = self.main_engine.get_order(vt_orderid) + orders.append(order) + return orders + + def get_trades(self, vt_orderid: str) -> Sequence[TradeData]: + """""" + trades = [] + all_trades = self.main_engine.get_all_trades() + + for trade in all_trades: + if trade.vt_orderid == vt_orderid: + trades.append(trade) + + return trades + + def get_all_active_orders(self) -> Sequence[OrderData]: + """""" + return self.main_engine.get_all_active_orders() + + def get_contract(self, vt_symbol) -> ContractData: + """""" + return self.main_engine.get_contract(vt_symbol) + + def get_all_contracts(self) -> Sequence[ContractData]: + """""" + return self.main_engine.get_all_contracts() + + def get_account(self, vt_accountid: str) -> AccountData: + """""" + return self.main_engine.get_account(vt_accountid) + + def get_all_accounts(self) -> Sequence[AccountData]: + """""" + return self.main_engine.get_all_accounts() + + def get_position(self, vt_positionid: str) -> PositionData: + """""" + return self.main_engine.get_position(vt_positionid) + + def get_all_positions(self) -> Sequence[AccountData]: + """""" + return self.main_engine.get_all_positions() + + def get_bars(self, vt_symbol: str, start_date: str, interval: Interval) -> Sequence[BarData]: + """""" + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + return [] + + start = datetime.strptime(start_date, "%Y%m%d") + + req = HistoryRequest( + symbol=contract.symbol, + exchange=contract.exchange, + start=start, + interval=interval + ) + + bars = rqdata_client.query_history(req) + if not bars: + return [] + return bars + + def write_log(self, msg: str) -> None: + """""" + log = LogData(msg=msg, gateway_name=APP_NAME) + print(f"{log.time}\t{log.msg}") + + event = Event(EVENT_SCRIPT_LOG, log) + self.event_engine.put(event) diff --git a/vnpy/app/script_trader/ui/__init__.py b/vnpy/app/script_trader/ui/__init__.py new file mode 100644 index 00000000..53442534 --- /dev/null +++ b/vnpy/app/script_trader/ui/__init__.py @@ -0,0 +1 @@ +from .widget import ScriptManager diff --git a/vnpy/app/script_trader/ui/script.ico b/vnpy/app/script_trader/ui/script.ico new file mode 100644 index 00000000..06fc921f Binary files /dev/null and b/vnpy/app/script_trader/ui/script.ico differ diff --git a/vnpy/app/script_trader/ui/widget.py b/vnpy/app/script_trader/ui/widget.py new file mode 100644 index 00000000..45358180 --- /dev/null +++ b/vnpy/app/script_trader/ui/widget.py @@ -0,0 +1,97 @@ +from pathlib import Path + +from vnpy.event import EventEngine, Event +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtWidgets, QtCore +from ..engine import APP_NAME, EVENT_SCRIPT_LOG + + +class ScriptManager(QtWidgets.QWidget): + """""" + signal_log = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + + self.script_engine = main_engine.get_engine(APP_NAME) + + self.script_path = "" + + self.init_ui() + self.register_event() + + self.script_engine.init() + + def init_ui(self): + """""" + self.setWindowTitle("脚本策略") + + start_button = QtWidgets.QPushButton("启动") + start_button.clicked.connect(self.start_script) + + stop_button = QtWidgets.QPushButton("停止") + stop_button.clicked.connect(self.stop_script) + + select_button = QtWidgets.QPushButton("打开") + select_button.clicked.connect(self.select_script) + + self.strategy_line = QtWidgets.QLineEdit() + + self.log_monitor = QtWidgets.QTextEdit() + self.log_monitor.setReadOnly(True) + + hbox = QtWidgets.QHBoxLayout() + hbox.addWidget(self.strategy_line) + hbox.addWidget(select_button) + hbox.addWidget(start_button) + hbox.addWidget(stop_button) + + vbox = QtWidgets.QVBoxLayout() + vbox.addLayout(hbox) + vbox.addWidget(self.log_monitor) + + self.setLayout(vbox) + + def register_event(self): + """""" + self.signal_log.connect(self.process_log_event) + + self.event_engine.register(EVENT_SCRIPT_LOG, self.signal_log.emit) + + def show(self): + """""" + self.showMaximized() + + def process_log_event(self, event: Event): + """""" + log = event.data + msg = f"{log.time}\t{log.msg}" + self.log_monitor.append(msg) + + def start_script(self): + """""" + if self.script_path: + self.script_engine.start_strategy(self.script_path) + + def stop_script(self): + """""" + self.script_engine.stop_strategy() + + def select_script(self): + """""" + cwd = str(Path.cwd()) + + path, type_ = QtWidgets.QFileDialog.getOpenFileName( + self, + u"载入策略脚本", + cwd, + "Python File(*.py)" + ) + + if path: + self.script_path = path + self.strategy_line.setText(path) diff --git a/vnpy/gateway/onetoken/onetoken_gateway.py b/vnpy/gateway/onetoken/onetoken_gateway.py index e140c37f..e6af5faa 100644 --- a/vnpy/gateway/onetoken/onetoken_gateway.py +++ b/vnpy/gateway/onetoken/onetoken_gateway.py @@ -619,7 +619,7 @@ class OnetokenTradeWebsocketApi(WebsocketClient): elif _type == "future": long_position = PositionData( symbol=account_data["contract"], - exchange=Exchange(self.exchange.upper()), + exchange=Exchange(self.exchange.upper()), direction=Direction.LONG, price=account_data["average_open_price_long"], volume=account_data["total_amount_long"], @@ -629,7 +629,7 @@ class OnetokenTradeWebsocketApi(WebsocketClient): ) short_position = PositionData( symbol=account_data["contract"], - exchange=Exchange(self.exchange.upper()), + exchange=Exchange(self.exchange.upper()), direction=Direction.SHORT, price=account_data["average_open_price_short"], volume=account_data["total_amount_short"], @@ -661,7 +661,7 @@ class OnetokenTradeWebsocketApi(WebsocketClient): gateway_name=self.gateway_name ) - if order_data["status"] in ("withdrawn","part-deal-withdrawn"): + if order_data["status"] in ("withdrawn", "part-deal-withdrawn"): order.status = Status.CANCELLED else: if order.traded == order.volume: diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 3aa882b0..79fa1634 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -3,6 +3,7 @@ import logging import smtplib +import os from abc import ABC from datetime import datetime from email.message import EmailMessage @@ -30,7 +31,7 @@ from .object import ( HistoryRequest ) from .setting import SETTINGS -from .utility import get_folder_path +from .utility import get_folder_path, TRADER_DIR class MainEngine: @@ -51,7 +52,8 @@ class MainEngine: self.apps = {} self.exchanges = [] - self.init_engines() + os.chdir(TRADER_DIR) # Change working directory + self.init_engines() # Initialize function engines def add_engine(self, engine_class: Any): """