diff --git a/requirements.txt b/requirements.txt index 1257ab49..fa2f5a18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,8 @@ PyQt5 qdarkstyle futu-api websocket-client -peewee \ No newline at end of file +peewee +numpy +pandas +matplotlib +seaborn \ No newline at end of file diff --git a/tests/test_all.py b/tests/test_all.py index f636d3fd..fba0404a 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,4 +1,4 @@ import unittest -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tools/ci/format_check.py b/tools/ci/format_check.py index c1455f19..b117a84e 100644 --- a/tools/ci/format_check.py +++ b/tools/ci/format_check.py @@ -5,20 +5,21 @@ from yapf.yapflib.yapf_api import FormatFile logger = logging.Logger(__file__) -if __name__ == '__main__': +if __name__ == "__main__": has_changed = False for root, dir, filenames in os.walk("vnpy"): for filename in filenames: basename, ext = os.path.splitext(filename) - if ext == '.py': + if ext == ".py": path = os.path.join(root, filename) - reformatted_code, encoding, changed = FormatFile(filename=path, - style_config='.style.yapf', - print_diff=True, - verify=False, - in_place=False, - logger=None - ) + reformatted_code, encoding, changed = FormatFile( + filename=path, + style_config=".style.yapf", + print_diff=True, + verify=False, + in_place=False, + logger=None, + ) if changed: has_changed = True logger.warning("File {} not formatted!".format(path)) diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index 967f9f59..aabc1ed7 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -12,10 +12,10 @@ from typing import Any, Callable, Optional class RequestStatus(Enum): - ready = 0 # Request created - success = 1 # Request successful (status code 2xx) + ready = 0 # Request created + success = 1 # Request successful (status code 2xx) failed = 2 # Request failed (status code not 2xx) - error = 3 # Exception raised + error = 3 # Exception raised class Request(object): @@ -24,16 +24,16 @@ class Request(object): """ def __init__( - self, - method: str, - path: str, - params: dict, - data: dict, - headers: dict, - callback: Callable, - on_failed: Callable = None, - on_error: Callable = None, - extra: Any = None + self, + method: str, + path: str, + params: dict, + data: dict, + headers: dict, + callback: Callable, + on_failed: Callable = None, + on_error: Callable = None, + extra: Any = None, ): """""" self.method = method @@ -52,7 +52,7 @@ class Request(object): def __str__(self): if self.response is None: - status_code = 'terminated' + status_code = "terminated" else: status_code = self.response.status_code @@ -70,7 +70,7 @@ class Request(object): self.headers, self.params, self.data, - '' if self.response is None else self.response.text + "" if self.response is None else self.response.text, ) ) @@ -88,11 +88,11 @@ class RestClient(object): def __init__(self): """ """ - self.url_base = None # type: str + self.url_base = None # type: str self._active = False self._queue = Queue() - self._pool = None # type: Pool + self._pool = None # type: Pool self.proxies = None @@ -135,16 +135,16 @@ class RestClient(object): self._queue.join() def add_request( - self, - method: str, - path: str, - callback: Callable, - params: dict = None, - data: dict = None, - headers: dict = None, - on_failed: Callable = None, - on_error: Callable = None, - extra: Any = None + self, + method: str, + path: str, + callback: Callable, + params: dict = None, + data: dict = None, + headers: dict = None, + on_failed: Callable = None, + on_error: Callable = None, + extra: Any = None, ): """ Add a new request. @@ -160,15 +160,7 @@ class RestClient(object): :return: Request """ request = Request( - method, - path, - params, - data, - headers, - callback, - on_failed, - on_error, - extra + method, path, params, data, headers, callback, on_failed, on_error, extra ) self._queue.put(request) return request @@ -204,54 +196,34 @@ class RestClient(object): sys.stderr.write(str(request)) def on_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request + self, exception_type: type, exception_value: Exception, tb, request: Request ): """ Default on_error handler for Python exception. """ sys.stderr.write( - self.exception_detail(exception_type, - exception_value, - tb, - request) + self.exception_detail(exception_type, exception_value, tb, request) ) sys.excepthook(exception_type, exception_value, tb) def exception_detail( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request + self, exception_type: type, exception_value: Exception, tb, request: Request ): text = "[{}]: Unhandled RestClient Error:{}\n".format( - datetime.now().isoformat(), - exception_type + datetime.now().isoformat(), exception_type ) text += "request:{}\n".format(request) text += "Exception trace: \n" - text += "".join( - traceback.format_exception( - exception_type, - exception_value, - tb, - ) - ) + text += "".join(traceback.format_exception(exception_type, exception_value, tb)) return text def _process_request( - self, - request: Request, - session: requests.session - ): # type: (Request, requests.Session)->None + self, request: Request, session: requests.session + ): # type: (Request, requests.Session)->None """ Sending request to server and get result. """ - # noinspection PyBroadException + # noinspection PyBroadException try: request = self.sign(request) @@ -263,12 +235,12 @@ class RestClient(object): headers=request.headers, params=request.params, data=request.data, - proxies=self.proxies + proxies=self.proxies, ) request.response = response status_code = response.status_code - if status_code / 100 == 2: # 2xx都算成功,尽管交易所都用200 + if status_code / 100 == 2: # 2xx都算成功,尽管交易所都用200 jsonBody = response.json() request.callback(jsonBody, request) request.status = RequestStatus.success diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index a4e71b14..a0db6c11 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -123,9 +123,9 @@ class WebsocketClient(object): """""" self._ws = self._create_connection( self.host, - sslopt={'cert_reqs': ssl.CERT_NONE}, + sslopt={"cert_reqs": ssl.CERT_NONE}, http_proxy_host=self.proxy_host, - http_proxy_port=self.proxy_port + http_proxy_port=self.proxy_port, ) self.on_connected() @@ -166,7 +166,7 @@ class WebsocketClient(object): try: data = self.unpack_data(text) except ValueError as e: - print('websocket unable to parse data: ' + text) + print("websocket unable to parse data: " + text) raise e self.on_packet(data) @@ -211,7 +211,7 @@ class WebsocketClient(object): """""" ws = self._get_ws() if ws: - ws.send('ping', websocket.ABNF.OPCODE_PING) + ws.send("ping", websocket.ABNF.OPCODE_PING) @staticmethod def on_connected(): @@ -238,36 +238,20 @@ class WebsocketClient(object): """ Callback when exception raised. """ - sys.stderr.write( - self.exception_detail(exception_type, - exception_value, - tb) - ) + sys.stderr.write(self.exception_detail(exception_type, exception_value, tb)) return sys.excepthook(exception_type, exception_value, tb) - def exception_detail( - self, - exception_type: type, - exception_value: Exception, - tb - ): + def exception_detail(self, exception_type: type, exception_value: Exception, tb): """ Print detailed exception information. """ text = "[{}]: Unhandled WebSocket Error:{}\n".format( - datetime.now().isoformat(), - exception_type + datetime.now().isoformat(), exception_type ) text += "LastSentText:\n{}\n".format(self._last_sent_text) text += "LastReceivedText:\n{}\n".format(self._last_received_text) text += "Exception trace: \n" - text += "".join( - traceback.format_exception( - exception_type, - exception_value, - tb, - ) - ) + text += "".join(traceback.format_exception(exception_type, exception_value, tb)) return text def _record_last_sent_text(self, text: str): @@ -280,4 +264,4 @@ class WebsocketClient(object): """ Record last received text for debug purpose. """ - self._last_received_text = text[:1000] \ No newline at end of file + self._last_received_text = text[:1000] diff --git a/vnpy/app/cta_strategy/__init__.py b/vnpy/app/cta_strategy/__init__.py index 2bc3bf51..1e29016d 100644 --- a/vnpy/app/cta_strategy/__init__.py +++ b/vnpy/app/cta_strategy/__init__.py @@ -8,10 +8,11 @@ from .base import APP_NAME class CtaStrategyApp(BaseApp): """""" + app_name = APP_NAME app_module = __module__ app_path = Path(__file__).parent display_name = "CTA策略" engine_class = CtaEngine widget_name = "CtaManager" - icon_name = "cta.ico" \ No newline at end of file + icon_name = "cta.ico" diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 4154b119..b008cfd0 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -4,6 +4,7 @@ from collections import defaultdict import numpy as np import matplotlib.pyplot as plt +import seaborn as sns from pandas import DataFrame from vnpy.trader.constant import Interval, Status, Direction, Exchange @@ -18,13 +19,16 @@ from .base import ( EngineType, StopOrder, BacktestingMode, - ORDER_CTA2VT + ORDER_CTA2VT, ) from .template import CtaTemplate +sns.set_style("whitegrid") + class BacktestingEngine: """""" + engine_type = EngineType.BACKTESTING gateway_name = "BACKTESTING" @@ -39,7 +43,7 @@ class BacktestingEngine: self.slippage = 0 self.size = 1 self.pricetick = 0 - self.capital = 1000000 + self.capital = 1_000_000 self.mode = BacktestingMode.BAR self.strategy = None @@ -92,17 +96,17 @@ class BacktestingEngine: self.daily_results.clear() def set_parameters( - self, - vt_symbol: str, - interval: Interval, - start: datetime, - rate: float, - slippage: float, - size: float, - pricetick: float, - capital: int = 0, - end: datetime = None, - mode: BacktestingMode = None + self, + vt_symbol: str, + interval: Interval, + start: datetime, + rate: float, + slippage: float, + size: float, + pricetick: float, + capital: int = 0, + end: datetime = None, + mode: BacktestingMode = None, ): """""" self.mode = mode @@ -124,10 +128,7 @@ class BacktestingEngine: def add_strategy(self, strategy_class: type, setting: dict): """""" self.strategy = strategy_class( - self, - strategy_class.__name__, - self.vt_symbol, - setting + self, strategy_class.__name__, self.vt_symbol, setting ) def load_data(self): @@ -135,18 +136,26 @@ class BacktestingEngine: self.output("开始加载历史数据") if self.mode == BacktestingMode.BAR: - s = DbBarData.select().where( - (DbBarData.vt_symbol == self.vt_symbol) & - (DbBarData.interval == self.interval) & - (DbBarData.datetime >= self.start) & - (DbBarData.datetime <= self.end) - ).order_by(DbBarData.datetime) + s = ( + DbBarData.select() + .where( + (DbBarData.vt_symbol == self.vt_symbol) + & (DbBarData.interval == self.interval) + & (DbBarData.datetime >= self.start) + & (DbBarData.datetime <= self.end) + ) + .order_by(DbBarData.datetime) + ) else: - s = DbTickData.select().where( - (DbTickData.vt_symbol == self.vt_symbol) & - (DbTickData.datetime >= self.start) & - (DbTickData.datetime <= self.end) - ).order_by(DbTickData.datetime) + s = ( + DbTickData.select() + .where( + (DbTickData.vt_symbol == self.vt_symbol) + & (DbTickData.datetime >= self.start) + & (DbTickData.datetime <= self.end) + ) + .order_by(DbTickData.datetime) + ) self.history_data = list(s) @@ -164,7 +173,7 @@ class BacktestingEngine: # Use the first [days] of history data for initializing strategy day_count = 0 for ix, data in enumerate(self.history_data): - if (self.datetime and data.datetime.day != self.datetime.day): + if self.datetime and data.datetime.day != self.datetime.day: day_count += 1 if day_count >= self.days: break @@ -205,11 +214,7 @@ class BacktestingEngine: for daily_result in self.daily_results.values(): daily_result.calculate_pnl( - pre_close, - start_pos, - self.size, - self.rate, - self.slippage + pre_close, start_pos, self.size, self.rate, self.slippage ) pre_close = daily_result.close_price @@ -236,78 +241,82 @@ class BacktestingEngine: # Calculate balance related time series data df["balance"] = df["net_pnl"].cumsum() + self.capital - df["return"] = (np.log(df["balance}" - np.log(df["balance"].shift(1))).fillna(0) - df["highlevel"] = df["balance"].rolling(min_periods=1,window=len(df),center=False).max() - df["drawdown"] = df["balance"] - df["highlevel"] + df["return"] = (np.log(df["balance"] - np.log(df["balance"].shift(1)))).fillna( + 0 + ) + df["highlevel"] = ( + df["balance"].rolling(min_periods=1, window=len(df), center=False).max() + ) + df["drawdown"] = df["balance"] - df["highlevel"] df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100 - + # Calculate statistics value start_date = df.index[0] end_date = df.index[-1] total_days = len(df) - profit_days = len(df[df["netPnl"]>0]) - loss_days = len(df[df["netPnl"]<0]) - + profit_days = len(df[df["netPnl"] > 0]) + loss_days = len(df[df["netPnl"] < 0]) + end_balance = df["balance"].iloc[-1] max_drawdown = df["drawdown"].min() max_ddpercent = df["ddpercent"].min() - + total_net_pnl = df["net_pnl"].sum() daily_net_pnl = total_net_pnl / total_days - + total_commission = df["commission"].sum() daily_commission = total_commission / total_days - + total_slippage = df["slippage"].sum() daily_slippage = total_slippage / total_days - + total_turnover = df["turnover"].sum() daily_turnover = total_turnover / total_days - + total_trade_count = df["trade_count"].sum() daily_trade_count = total_trade_count / total_days - - total_return = (end_balance/self.capital - 1) * 100 + + total_return = (end_balance / self.capital - 1) * 100 annual_return = total_return / total_days * 240 daily_return = df["return"].mean() * 100 return_std = df["return"].std() * 100 - + if return_std: sharpe_ratio = daily_return / return_std * np.sqrt(240) else: sharpe_ratio = 0 - + # Output - # 输出统计结果 + # 输出统计结果 self.output("-" * 30) self.output(f"首个交易日:\t{start_date}") self.output(f"最后交易日:\t{end_date}") - + self.output(f"总交易日:\t{total_days}") self.output(f"盈利交易日:\t{profit_days}") self.output(f"亏损交易日:\t{loss_days}") - + self.output(f"起始资金:\t{self.capital:,.2f}") self.output(f"结束资金:\t{end_balance:,.2f}") - + self.output(f"总收益率:\t{total_return:,.2f}%") self.output(f"年化收益:\t{annual_return:,.2f}%") - self.output(f"最大回撤: \t{max_drawdown:,.2f}%") - self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") - + self.output(f"最大回撤: \t{max_drawdown:,.2f}%") + self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") + self.output(f"总盈亏:\t{total_net_pnl:,.2f}%") self.output(f"总手续费:\t{total_commission:,.2f}") self.output(f"总滑点:\t{total_slippage:,.2f}") self.output(f"总成交金额:\t{total_turnover:,.2f}") self.output(f"总成交笔数:\t{total_trade_count}") - + self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") self.output(f"日均手续费:\t{daily_commission:,.2f}") self.output(f"日均滑点:\t{daily_slippage:,.2f}") self.output(f"日均成交金额:\t{daily_turnover:,.2f}") self.output(f"日均成交笔数:\t{daily_trade_count}") - + self.output(f"日均收益率:\t{daily_return:,.2f}%") self.output(f"收益标准差:\t{return_std:,.2f}%") self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") @@ -335,33 +344,33 @@ class BacktestingEngine: "annual_return": annual_return, "daily_return": daily_return, "return_std": return_std, - "sharpe_ratio": sharpe_ratio + "sharpe_ratio": sharpe_ratio, } - + return statistics def show_chart(self, df: DataFrame = None): """""" if not df: df = self.daily_df - + fig = plt.figure(figsize=(10, 16)) - + balance_plot = plt.subplot(4, 1, 1) - balance_plot.set_title('Balance') - df['balance'].plot(legend=True) - + balance_plot.set_title("Balance") + df["balance"].plot(legend=True) + drawdown_plot = plt.subplot(4, 1, 2) - drawdown_plot.set_title('Drawdown') - drawdown_plot.fill_between(range(len(df)), df['drawdown'].values) - + drawdown_plot.set_title("Drawdown") + drawdown_plot.fill_between(range(len(df)), df["drawdown"].values) + pnl_plot = plt.subplot(4, 1, 3) - pnl_plot.set_title('Daily Pnl') - df['net_pnl'].plot(kind='bar', legend=False, grid=False, xticks=[]) + pnl_plot.set_title("Daily Pnl") + df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[]) distribution_plot = plt.subplot(4, 1, 4) - distribution_plot.set_title('Daily Pnl Distribution') - df['net_pnl'].hist(bins=50) + distribution_plot.set_title("Daily Pnl Distribution") + df["net_pnl"].hist(bins=50) plt.show() @@ -421,12 +430,14 @@ class BacktestingEngine: # Check whether limit orders can be filled. long_cross = ( order.direction == Direction.LONG - and order.price >= long_cross_price and long_cross_price > 0 + and order.price >= long_cross_price + and long_cross_price > 0 ) short_cross = ( order.direction == Direction.SHORT - and order.price <= short_cross_price and short_cross_price > 0 + and order.price <= short_cross_price + and short_cross_price > 0 ) if not long_cross and not short_cross: @@ -459,7 +470,7 @@ class BacktestingEngine: price=trade_price, volume=order.volume, time=self.datetime.strftime("%H:%M:%S"), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) trade.datetime = self.datetime @@ -510,7 +521,7 @@ class BacktestingEngine: price=stop_order.price, volume=stop_order.volume, status=Status.ALLTRADED, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.limit_orders[order.vt_orderid] = order @@ -535,7 +546,7 @@ class BacktestingEngine: price=trade_price, volume=order.volume, time=self.datetime.strftime("%H:%M:%S"), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) trade.datetime = self.datetime @@ -553,11 +564,7 @@ class BacktestingEngine: self.strategy.on_trade(trade) def load_bar( - self, - vt_symbol: str, - days: int, - interval: Interval, - callback: Callable + self, vt_symbol: str, days: int, interval: Interval, callback: Callable ): """""" self.days = days @@ -569,12 +576,12 @@ class BacktestingEngine: self.callback = callback def send_order( - self, - strategy: CtaTemplate, - order_type: CtaOrderType, - price: float, - volume: float, - stop: bool = False + self, + strategy: CtaTemplate, + order_type: CtaOrderType, + price: float, + volume: float, + stop: bool = False, ): """""" if stop: @@ -582,12 +589,7 @@ class BacktestingEngine: else: return self.send_limit_order(order_type, price, volume) - def send_stop_order( - self, - order_type: CtaOrderType, - price: float, - volume: float - ): + def send_stop_order(self, order_type: CtaOrderType, price: float, volume: float): """""" self.stop_order_count += 1 @@ -597,7 +599,7 @@ class BacktestingEngine: price=price, volume=volume, stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}", - strategy_name=self.strategy.strategy_name + strategy_name=self.strategy.strategy_name, ) self.active_stop_orders[stop_order.stop_orderid] = stop_order @@ -605,12 +607,7 @@ class BacktestingEngine: return stop_order.stop_orderid - def send_limit_order( - self, - order_type: CtaOrderType, - price: float, - volume: float - ): + def send_limit_order(self, order_type: CtaOrderType, price: float, volume: float): """""" self.limit_order_count += 1 direction, offset = ORDER_CTA2VT[order_type] @@ -624,7 +621,7 @@ class BacktestingEngine: price=price, volume=volume, status=Status.NOTTRADED, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.active_limit_orders[order.vt_orderid] = order @@ -724,19 +721,17 @@ class DailyResult: self.trades.append(trade) def calculate_pnl( - self, - pre_close: float, - start_pos: float, - size: int, - rate: float, - slippage: float + self, + pre_close: float, + start_pos: float, + size: int, + rate: float, + slippage: float, ): """""" # Holding pnl is the pnl from holding position at day start self.start_pos = self.end_pos = start_pos - self.holding_pnl = self.start_pos * ( - self.close_price - self.pre_close - ) * size + self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size # Trading pnl is the pnl from new trade during the day self.trade_count = len(self.trades) @@ -749,9 +744,7 @@ class DailyResult: pos_change -= trade.volume turnover = trade.price * trade.volume * size - self.trading_pnl += pos_change * ( - self.close_price - trade.price - ) * size + self.trading_pnl += pos_change * (self.close_price - trade.price) * size self.end_pos += pos_change self.turnover += turnover self.commission += turnover * rate @@ -760,3 +753,39 @@ class DailyResult: # Net pnl takes account of commission and slippage cost self.total_pnl = self.trading_pnl + self.holding_pnl self.net_pnl = self.total_pnl - self.commission - self.slippage + + +class OptimizationSetting: + """ + Setting for runnning optimization. + """ + + def __init__(self): + """""" + self.params = {} + self.target = "" + + def add_parameter( + self, name: str, start: float, end: float = None, step: float = None + ): + """""" + if not end and not step: + self.params[name] = [start] + return + + if start >= end: + print("参数优化起始点必须小于终止点") + return + + if step <= 0: + print("参数优化步进必须大于0") + return + + value = start + value_list = [] + + while value <= end: + value_list.append(value) + value += step + + self.params[name] = value_list diff --git a/vnpy/app/cta_strategy/base.py b/vnpy/app/cta_strategy/base.py index 1e8d6bd2..0b3c28d4 100644 --- a/vnpy/app/cta_strategy/base.py +++ b/vnpy/app/cta_strategy/base.py @@ -51,17 +51,13 @@ class StopOrder: self.direction, self.offset = ORDER_CTA2VT[self.order_type] -EVENT_CTA_LOG = 'eCtaLog' -EVENT_CTA_STRATEGY = 'eCtaStrategy' -EVENT_CTA_STOPORDER = 'eCtaStopOrder' +EVENT_CTA_LOG = "eCtaLog" +EVENT_CTA_STRATEGY = "eCtaStrategy" +EVENT_CTA_STOPORDER = "eCtaStopOrder" ORDER_CTA2VT = { - CtaOrderType.BUY: (Direction.LONG, - Offset.OPEN), - CtaOrderType.SELL: (Direction.SHORT, - Offset.CLOSE), - CtaOrderType.SHORT: (Direction.SHORT, - Offset.OPEN), - CtaOrderType.COVER: (Direction.LONG, - Offset.CLOSE), + CtaOrderType.BUY: (Direction.LONG, Offset.OPEN), + CtaOrderType.SELL: (Direction.SHORT, Offset.CLOSE), + CtaOrderType.SHORT: (Direction.SHORT, Offset.OPEN), + CtaOrderType.COVER: (Direction.LONG, Offset.CLOSE), } diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index d69a525d..2e9130ec 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -15,7 +15,7 @@ from vnpy.trader.object import ( CancelRequest, SubscribeRequest, LogData, - TickData + TickData, ) from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interval @@ -31,36 +31,32 @@ from .base import ( ORDER_CTA2VT, EVENT_CTA_LOG, EVENT_CTA_STRATEGY, - EVENT_CTA_STOPORDER + EVENT_CTA_STOPORDER, ) class CtaEngine(BaseEngine): """""" - engine_type = EngineType.LIVE # live trading engine + + engine_type = EngineType.LIVE # live trading engine filename = "CtaStrategy.vt" def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" - super(CtaEngine, - self).__init__(main_engine, - event_engine, - "CtaStrategy") + super(CtaEngine, self).__init__(main_engine, event_engine, "CtaStrategy") - self.setting_file = None # setting file object + self.setting_file = None # setting file object - self.classes = {} # class_name: stategy_class - self.strategies = {} # strategy_name: strategy + self.classes = {} # class_name: stategy_class + self.strategies = {} # strategy_name: strategy - self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list - self.orderid_strategy_map = {} # vt_orderid: strategy - self.strategy_orderid_map = defaultdict( - set - ) # strategy_name: orderid list + self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list + self.orderid_strategy_map = {} # vt_orderid: strategy + self.strategy_orderid_map = defaultdict(set) # strategy_name: orderid list - self.stop_order_count = 0 # for generating stop_orderid - self.stop_orders = {} # stop_orderid: stop_order + self.stop_order_count = 0 # for generating stop_orderid + self.stop_orders = {} # stop_orderid: stop_order def init_engine(self): """ @@ -131,12 +127,10 @@ class CtaEngine(BaseEngine): continue long_triggered = ( - so.direction == Direction.LONG - and tick.last_price >= stop_order.price + so.direction == Direction.LONG and tick.last_price >= stop_order.price ) short_triggered = ( - so.direction == Direction.SHORT - and tick.last_price <= stop_order.price + so.direction == Direction.SHORT and tick.last_price <= stop_order.price ) if long_triggered or short_triggered: @@ -157,10 +151,7 @@ class CtaEngine(BaseEngine): price = tick.bid_price_5 vt_orderid = self.send_limit_order( - strategy, - stop_order.order_type, - price, - stop_order.volume + strategy, stop_order.order_type, price, stop_order.volume ) # Update stop order status if placed successfully @@ -177,17 +168,15 @@ class CtaEngine(BaseEngine): stop_order.vt_orderid = vt_orderid self.call_strategy_func( - strategy, - strategy.on_stop_order, - stop_order + strategy, strategy.on_stop_order, stop_order ) def send_limit_order( - self, - strategy: CtaTemplate, - order_type: CtaOrderType, - price: float, - volume: float + self, + strategy: CtaTemplate, + order_type: CtaOrderType, + price: float, + volume: float, ): """ Send a new order. @@ -207,12 +196,9 @@ class CtaEngine(BaseEngine): offset=offset, price_type=PriceType.LIMIT, price=price, - volume=volume - ) - vt_orderid = self.main_engine.send_limit_order( - req, - contract.gateway_name + volume=volume, ) + vt_orderid = self.main_engine.send_limit_order(req, contract.gateway_name) # Save relationship between orderid and strategy. self.orderid_strategy_map[vt_orderid] = strategy @@ -223,11 +209,11 @@ class CtaEngine(BaseEngine): return vt_orderid def send_stop_order( - self, - strategy: CtaTemplate, - order_type: CtaOrderType, - price: float, - volume: float + self, + strategy: CtaTemplate, + order_type: CtaOrderType, + price: float, + volume: float, ): """ Send a new order. @@ -243,7 +229,7 @@ class CtaEngine(BaseEngine): price=price, volume=volume, stop_orderid=stop_orderid, - strategy_name=strategy.strategy_name + strategy_name=strategy.strategy_name, ) self.stop_orders[stop_orderid] = stop_order @@ -289,12 +275,12 @@ class CtaEngine(BaseEngine): self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) def send_order( - self, - strategy: CtaTemplate, - order_type: CtaOrderType, - price: float, - volume: float, - stop: bool + self, + strategy: CtaTemplate, + order_type: CtaOrderType, + price: float, + volume: float, + stop: bool, ): """ """ @@ -327,11 +313,7 @@ class CtaEngine(BaseEngine): return self.engine_type def load_bar( - self, - vt_symbol: str, - days: int, - interval: Interval, - callback: Callable + self, vt_symbol: str, days: int, interval: Interval, callback: Callable ): """""" pass @@ -341,10 +323,7 @@ class CtaEngine(BaseEngine): pass def call_strategy_func( - self, - strategy: CtaTemplate, - func: Callable, - params: Any = None + self, strategy: CtaTemplate, func: Callable, params: Any = None ): """ Call function of a strategy and catch any exception raised. @@ -362,11 +341,7 @@ class CtaEngine(BaseEngine): self.write_log(msg, strategy) def add_strategy( - self, - class_name: str, - strategy_name: str, - vt_symbol: str, - setting: dict + self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict ): """ Add a new strategy. @@ -462,29 +437,18 @@ class CtaEngine(BaseEngine): Load strategy class from source code. """ path1 = Path(__file__).parent.joinpath("strategies") - self.load_strategy_class_from_folder( - path1, - "vnpy.app.cta_strategy.strategies" - ) + self.load_strategy_class_from_folder(path1, "vnpy.app.cta_strategy.strategies") path2 = Path.cwd().joinpath("strategies") self.load_strategy_class_from_folder(path2, "strategies") - def load_strategy_class_from_folder( - self, - path: Path, - module_name: str = "" - ): + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): """ Load strategy class from certain folder. """ for dirpath, dirnames, filenames in os.walk(path): for filename in filenames: - module_name = ".".join( - [module_name, - filename.replace(".py", - "")] - ) + module_name = ".".join([module_name, filename.replace(".py", "")]) self.load_strategy_class_from_module(module_name) def load_strategy_class_from_module(self, module_name: str): @@ -566,7 +530,7 @@ class CtaEngine(BaseEngine): strategy.__class__.__name__, strategy_name, strategy.vt_symbol, - setting + setting, ) self.setting_file.sync() @@ -611,4 +575,4 @@ class CtaEngine(BaseEngine): log = LogData(msg=msg, gateway_name="CtaStrategy") event = Event(type=EVENT_CTA_LOG, data=log) - self.event_engine.put(event) \ No newline at end of file + self.event_engine.put(event) diff --git a/vnpy/app/cta_strategy/strategies/double_ma_strategy.py b/vnpy/app/cta_strategy/strategies/double_ma_strategy.py index e292503e..513603ee 100644 --- a/vnpy/app/cta_strategy/strategies/double_ma_strategy.py +++ b/vnpy/app/cta_strategy/strategies/double_ma_strategy.py @@ -16,8 +16,6 @@ class DoubleMaStrategy(CtaTemplate): def __init__(self, cta_engine, strategy_name, vt_symbol, setting): """""" - super(DoubleMaStrategy, - self).__init__(cta_engine, - strategy_name, - vt_symbol, - setting) + super(DoubleMaStrategy, self).__init__( + cta_engine, strategy_name, vt_symbol, setting + ) diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index 31d959ce..da9bc718 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -1,7 +1,7 @@ """""" from abc import ABC -from typing import Any +from typing import Any, Callable from vnpy.trader.object import TickData, OrderData, TradeData, BarData from vnpy.trader.constant import Interval @@ -17,11 +17,7 @@ class CtaTemplate(ABC): variables = [] def __init__( - self, - cta_engine: Any, - strategy_name: str, - vt_symbol: str, - setting: dict + self, cta_engine: Any, strategy_name: str, vt_symbol: str, setting: dict ): """""" self.cta_engine = cta_engine @@ -84,7 +80,7 @@ class CtaTemplate(ABC): "class_name": self.__class__.__name__, "author": self.author, "parameters": self.get_parameters(), - "variables": self.get_variables() + "variables": self.get_variables(), } return strategy_data @@ -155,11 +151,7 @@ class CtaTemplate(ABC): return self.send_order(CtaOrderType.COVER, price, volume, stop) def send_order( - self, - order_type: CtaOrderType, - price: float, - volume: float, - stop: bool = False + self, order_type: CtaOrderType, price: float, volume: float, stop: bool = False ): """ Send a new order. @@ -191,14 +183,14 @@ class CtaTemplate(ABC): return self.cta_engine.get_engine_type() def load_bar( - self, - days: int, - interval: Interval = Interval.MINUTE, - callback=self.on_bar + self, days: int, interval: Interval = Interval.MINUTE, callback: Callable = None ): """ Load historical bar data for initializing strategy. """ + if not callback: + callback = self.on_bar + self.cta_engine.load_bar(self.vt_symbol, days, interval, callback) def load_tick(self, days: int): diff --git a/vnpy/app/cta_strategy/ui/__init__.py b/vnpy/app/cta_strategy/ui/__init__.py index ce7bb539..592d401a 100644 --- a/vnpy/app/cta_strategy/ui/__init__.py +++ b/vnpy/app/cta_strategy/ui/__init__.py @@ -1 +1 @@ -from .widget import CtaManager \ No newline at end of file +from .widget import CtaManager diff --git a/vnpy/app/cta_strategy/ui/widget.py b/vnpy/app/cta_strategy/ui/widget.py index b8ca175c..8dbe088b 100644 --- a/vnpy/app/cta_strategy/ui/widget.py +++ b/vnpy/app/cta_strategy/ui/widget.py @@ -11,6 +11,7 @@ from ..base import APP_NAME, EVENT_CTA_LOG, EVENT_CTA_STOPORDER, EVENT_CTA_STRAT class CtaManager(QtWidgets.QWidget): """""" + signal_log = QtCore.pyqtSignal(Event) signal_strategy = QtCore.pyqtSignal(Event) @@ -61,10 +62,7 @@ class CtaManager(QtWidgets.QWidget): self.log_monitor = LogMonitor(self.main_engine, self.event_engine) - self.stop_order_monitor = StopOrderMonitor( - self.main_engine, - self.event_engine - ) + self.stop_order_monitor = StopOrderMonitor(self.main_engine, self.event_engine) # Set layout hbox1 = QtWidgets.QHBoxLayout() @@ -88,18 +86,13 @@ class CtaManager(QtWidgets.QWidget): def update_class_combo(self): """""" - self.class_combo.addItems( - self.cta_engine.get_all_strategy_class_names() - ) + self.class_combo.addItems(self.cta_engine.get_all_strategy_class_names()) def register_event(self): """""" self.signal_strategy.connect(self.process_strategy_event) - self.event_engine.register( - EVENT_CTA_STRATEGY, - self.signal_strategy.emit - ) + self.event_engine.register(EVENT_CTA_STRATEGY, self.signal_strategy.emit) def process_strategy_event(self, event): """ @@ -136,12 +129,7 @@ class CtaManager(QtWidgets.QWidget): vt_symbol = setting.pop("vt_symbol") strategy_name = setting.pop("strategy_name") - self.cta_engine.add_strategy( - class_name, - strategy_name, - vt_symbol, - setting - ) + self.cta_engine.add_strategy(class_name, strategy_name, vt_symbol, setting) def show(self): """""" @@ -153,12 +141,7 @@ class StrategyManager(QtWidgets.QFrame): Manager for a strategy """ - def __init__( - self, - cta_manager: CtaManager, - cta_engine: CtaEngine, - data: dict - ): + def __init__(self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict): """""" super(StrategyManager, self).__init__() @@ -277,9 +260,7 @@ class DataMonitor(QtWidgets.QTableWidget): self.setHorizontalHeaderLabels(labels) self.setRowCount(1) - self.verticalHeader().setSectionResizeMode( - QtWidgets.QHeaderView.Stretch - ) + self.verticalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch) self.verticalHeader().setVisible(False) self.setEditTriggers(self.NoEditTriggers) @@ -320,51 +301,20 @@ class StopOrderMonitor(BaseMonitor): """ Monitor for local stop order. """ + event_type = EVENT_CTA_STOPORDER data_key = "stop_orderid" sorting = True headers = { - "stop_orderid": { - "display": "停止委托号", - "cell": BaseCell, - "update": False - }, - "vt_orderid": { - "display": "限价委托号", - "cell": BaseCell, - "update": True - }, - "vt_symbol": { - "display": "代码", - "cell": BaseCell, - "update": False - }, - "order_type": { - "display": "类型", - "cell": EnumCell, - "update": False - }, - "price": { - "display": "价格", - "cell": BaseCell, - "update": False - }, - "volume": { - "display": "数量", - "cell": BaseCell, - "update": True - }, - "status": { - "display": "状态", - "cell": EnumCell, - "update": True - }, - "strategy": { - "display": "策略名", - "cell": StrategyCell, - "update": True - } + "stop_orderid": {"display": "停止委托号", "cell": BaseCell, "update": False}, + "vt_orderid": {"display": "限价委托号", "cell": BaseCell, "update": True}, + "vt_symbol": {"display": "代码", "cell": BaseCell, "update": False}, + "order_type": {"display": "类型", "cell": EnumCell, "update": False}, + "price": {"display": "价格", "cell": BaseCell, "update": False}, + "volume": {"display": "数量", "cell": BaseCell, "update": True}, + "status": {"display": "状态", "cell": EnumCell, "update": True}, + "strategy": {"display": "策略名", "cell": StrategyCell, "update": True}, } def init_ui(self): @@ -373,30 +323,21 @@ class StopOrderMonitor(BaseMonitor): """ super(StopOrderMonitor, self).init_ui() - self.horizontalHeader().setSectionResizeMode( - QtWidgets.QHeaderView.Stretch - ) + self.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch) class LogMonitor(BaseMonitor): """ Monitor for log data. """ + event_type = EVENT_CTA_LOG data_key = "" sorting = False headers = { - "time": { - "display": "时间", - "cell": TimeCell, - "update": False - }, - "msg": { - "display": "信息", - "cell": MsgCell, - "update": False - } + "time": {"display": "时间", "cell": TimeCell, "update": False}, + "msg": {"display": "信息", "cell": MsgCell, "update": False}, } def init_ui(self): @@ -405,10 +346,7 @@ class LogMonitor(BaseMonitor): """ super(LogMonitor, self).init_ui() - self.horizontalHeader().setSectionResizeMode( - 1, - QtWidgets.QHeaderView.Stretch - ) + self.horizontalHeader().setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch) def insert_new_row(self, data): """ @@ -423,12 +361,7 @@ class SettingEditor(QtWidgets.QDialog): For creating new strategy and editing strategy parameters. """ - def __init__( - self, - parameters: dict, - strategy_name: str = "", - class_name: str = "" - ): + def __init__(self, parameters: dict, strategy_name: str = "", class_name: str = ""): """""" super(SettingEditor, self).__init__() diff --git a/vnpy/event/__init__.py b/vnpy/event/__init__.py index db211d01..2aad05ac 100644 --- a/vnpy/event/__init__.py +++ b/vnpy/event/__init__.py @@ -1 +1 @@ -from .engine import Event, EventEngine, EVENT_TIMER \ No newline at end of file +from .engine import Event, EventEngine, EVENT_TIMER diff --git a/vnpy/gateway/bitmex/__init__.py b/vnpy/gateway/bitmex/__init__.py index e11ff09c..d68cbad5 100644 --- a/vnpy/gateway/bitmex/__init__.py +++ b/vnpy/gateway/bitmex/__init__.py @@ -1 +1 @@ -from .bitmex_gateway import BitmexGateway \ No newline at end of file +from .bitmex_gateway import BitmexGateway diff --git a/vnpy/gateway/bitmex/bitmex_gateway.py b/vnpy/gateway/bitmex/bitmex_gateway.py index f4dee0f8..55693f01 100644 --- a/vnpy/gateway/bitmex/bitmex_gateway.py +++ b/vnpy/gateway/bitmex/bitmex_gateway.py @@ -31,7 +31,7 @@ from vnpy.trader.object import ( TradeData, PositionData, AccountData, - ContractData + ContractData, ) from vnpy.trader.constant import Direction, Status, PriceType, Exchange, Product @@ -46,7 +46,7 @@ STATUS_BITMEX2VT = { "Partially filled": Status.PARTTRADED, "Filled": Status.ALLTRADED, "Canceled": Status.CANCELLED, - "Rejected": Status.REJECTED + "Rejected": Status.REJECTED, } DIRECTION_VT2BITMEX = {Direction.LONG: "Buy", Direction.SHORT: "Sell"} @@ -64,10 +64,9 @@ class BitmexGateway(BaseGateway): "key": "", "secret": "", "session": 3, - "server": ["REAL", - "TESTNET"], + "server": ["REAL", "TESTNET"], "proxy_host": "127.0.0.1", - "proxy_port": 1080 + "proxy_port": 1080, } def __init__(self, event_engine): @@ -86,14 +85,7 @@ class BitmexGateway(BaseGateway): proxy_host = setting["proxy_host"] proxy_port = setting["proxy_port"] - self.rest_api.connect( - key, - secret, - session, - server, - proxy_host, - proxy_port - ) + self.rest_api.connect(key, secret, session, server, proxy_host, proxy_port) self.ws_api.connect(key, secret, server, proxy_host, proxy_port) @@ -138,7 +130,7 @@ class BitmexRestApi(RestClient): self.key = "" self.secret = "" - self.order_count = 1000000 + self.order_count = 1_000_000 self.connect_time = 0 def sign(self, request): @@ -161,9 +153,7 @@ class BitmexRestApi(RestClient): msg = request.method + "/api/v1" + path + str(expires) + request.data signature = hmac.new( - self.secret, - msg.encode(), - digestmod=hashlib.sha256 + self.secret, msg.encode(), digestmod=hashlib.sha256 ).hexdigest() # Add headers @@ -172,20 +162,20 @@ class BitmexRestApi(RestClient): "Accept": "application/json", "api-key": self.key, "api-expires": str(expires), - "api-signature": signature + "api-signature": signature, } request.headers = headers return request def connect( - self, - key: str, - secret: str, - session: int, - server: str, - proxy_host: str, - proxy_port: int + self, + key: str, + secret: str, + session: int, + server: str, + proxy_host: str, + proxy_port: int, ): """ Initialize connection to REST server. @@ -193,9 +183,9 @@ class BitmexRestApi(RestClient): self.key = key self.secret = secret.encode() - self.connect_time = int( - datetime.now().strftime("%y%m%d%H%M%S") - ) * self.order_count + self.connect_time = ( + int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count + ) if server == "REAL": self.init(REST_HOST, proxy_host, proxy_port) @@ -204,7 +194,7 @@ class BitmexRestApi(RestClient): self.start(session) - self.gateway.write_log(u"REST API启动成功") + self.gateway.write_log("REST API启动成功") def send_order(self, req: SubscribeRequest): """""" @@ -217,7 +207,7 @@ class BitmexRestApi(RestClient): "ordType": PRICETYPE_VT2BITMEX[req.price_type], "price": req.price, "orderQty": int(req.volume), - "clOrdID": orderid + "clOrdID": orderid, } # Only add price for limit order. @@ -268,11 +258,7 @@ class BitmexRestApi(RestClient): self.gateway.write_log(msg) def on_send_order_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request + self, exception_type: type, exception_value: Exception, tb, request: Request ): """ Callback when sending order caused exception. @@ -290,11 +276,7 @@ class BitmexRestApi(RestClient): pass def on_cancel_order_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request + self, exception_type: type, exception_value: Exception, tb, request: Request ): """ Callback when cancelling order failed on server. @@ -315,11 +297,7 @@ class BitmexRestApi(RestClient): self.gateway.write_log(msg) def on_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request + self, exception_type: type, exception_value: Exception, tb, request: Request ): """ Callback to handler request exception. @@ -328,10 +306,7 @@ class BitmexRestApi(RestClient): self.gateway.write_log(msg) sys.stderr.write( - self.exception_detail(exception_type, - exception_value, - tb, - request) + self.exception_detail(exception_type, exception_value, tb, request) ) @@ -355,7 +330,7 @@ class BitmexWebsocketApi(WebsocketClient): "order": self.on_order, "position": self.on_position, "margin": self.on_account, - "instrument": self.on_contract + "instrument": self.on_contract, } self.ticks = {} @@ -364,12 +339,7 @@ class BitmexWebsocketApi(WebsocketClient): self.trades = set() def connect( - self, - key: str, - secret: str, - server: str, - proxy_host: str, - proxy_port: int + self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int ): """""" self.key = key @@ -391,23 +361,23 @@ class BitmexWebsocketApi(WebsocketClient): exchange=req.exchange, name=req.symbol, datetime=datetime.now(), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.ticks[req.symbol] = tick def on_connected(self): """""" - self.gateway.write_log(u"Websocket API连接成功") + self.gateway.write_log("Websocket API连接成功") self.authenticate() def on_disconnected(self): """""" - self.gateway.write_log(u"Websocket API连接断开") + self.gateway.write_log("Websocket API连接断开") def on_packet(self, packet: dict): """""" if "error" in packet: - self.gateway.write_log(u"Websocket API报错:%s" % packet["error"]) + self.gateway.write_log("Websocket API报错:%s" % packet["error"]) if "not valid" in packet["error"]: self.active = False @@ -418,7 +388,7 @@ class BitmexWebsocketApi(WebsocketClient): if success: if req["op"] == "authKey": - self.gateway.write_log(u"Websocket API验证授权成功") + self.gateway.write_log("Websocket API验证授权成功") self.subscribe_topic() elif "table" in packet: @@ -436,11 +406,7 @@ class BitmexWebsocketApi(WebsocketClient): msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" self.gateway.write_log(msg) - sys.stderr.write( - self.exception_detail(exception_type, - exception_value, - tb) - ) + sys.stderr.write(self.exception_detail(exception_type, exception_value, tb)) def authenticate(self): """ @@ -451,9 +417,7 @@ class BitmexWebsocketApi(WebsocketClient): path = "/realtime" msg = method + path + str(expires) signature = hmac.new( - self.secret, - msg.encode(), - digestmod=hashlib.sha256 + self.secret, msg.encode(), digestmod=hashlib.sha256 ).hexdigest() req = {"op": "authKey", "args": [self.key, expires, signature]} @@ -464,8 +428,7 @@ class BitmexWebsocketApi(WebsocketClient): Subscribe to all private topics. """ req = { - "op": - "subscribe", + "op": "subscribe", "args": [ "instrument", "trade", @@ -473,8 +436,8 @@ class BitmexWebsocketApi(WebsocketClient): "execution", "order", "position", - "margin" - ] + "margin", + ], } self.send_packet(req) @@ -486,10 +449,7 @@ class BitmexWebsocketApi(WebsocketClient): return tick.last_price = d["price"] - tick.datetime = datetime.strptime( - d["timestamp"], - "%Y-%m-%dT%H:%M:%S.%fZ" - ) + tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") self.gateway.on_tick(copy(tick)) def on_depth(self, d): @@ -509,10 +469,7 @@ class BitmexWebsocketApi(WebsocketClient): tick.__setattr__("ask_price_%s" % (n + 1), price) tick.__setattr__("ask_volume_%s" % (n + 1), volume) - tick.datetime = datetime.strptime( - d["timestamp"], - "%Y-%m-%dT%H:%M:%S.%fZ" - ) + tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") self.gateway.on_tick(copy(tick)) def on_trade(self, d): @@ -540,7 +497,7 @@ class BitmexWebsocketApi(WebsocketClient): price=d["lastPx"], volume=d["lastQty"], time=d["timestamp"][11:19], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.gateway.on_trade(trade) @@ -568,7 +525,7 @@ class BitmexWebsocketApi(WebsocketClient): price=d["price"], volume=d["orderQty"], time=d["timestamp"][11:19], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.orders[sysid] = order @@ -584,7 +541,7 @@ class BitmexWebsocketApi(WebsocketClient): exchange=Exchange.BITMEX, direction=Direction.NET, volume=d["currentQty"], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.gateway.on_position(position) @@ -594,10 +551,7 @@ class BitmexWebsocketApi(WebsocketClient): accountid = str(d["account"]) account = self.accounts.get(accountid, None) if not account: - account = AccountData( - accountid=accountid, - gateway_name=self.gateway_name - ) + account = AccountData(accountid=accountid, gateway_name=self.gateway_name) self.accounts[accountid] = account account.balance = d.get("marginBalance", account.balance) @@ -621,7 +575,7 @@ class BitmexWebsocketApi(WebsocketClient): product=Product.FUTURES, pricetick=d["tickSize"], size=d["lotSize"], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.gateway.on_contract(contract) diff --git a/vnpy/gateway/futu/__init__.py b/vnpy/gateway/futu/__init__.py index ba95dced..44b7509d 100644 --- a/vnpy/gateway/futu/__init__.py +++ b/vnpy/gateway/futu/__init__.py @@ -1 +1 @@ -from .futu_gateway import FutuGateway \ No newline at end of file +from .futu_gateway import FutuGateway diff --git a/vnpy/gateway/futu/futu_gateway.py b/vnpy/gateway/futu/futu_gateway.py index 6870ca58..0df5f9c7 100644 --- a/vnpy/gateway/futu/futu_gateway.py +++ b/vnpy/gateway/futu/futu_gateway.py @@ -22,7 +22,7 @@ from futu import ( StockQuoteHandlerBase, OrderBookHandlerBase, TradeOrderHandlerBase, - TradeDealHandlerBase + TradeDealHandlerBase, ) from vnpy.trader.gateway import BaseGateway @@ -36,14 +36,14 @@ from vnpy.trader.object import ( AccountData, SubscribeRequest, OrderRequest, - CancelRequest + CancelRequest, ) from vnpy.trader.event import EVENT_TIMER EXCHANGE_VT2FUTU = { Exchange.SMART: "US", Exchange.SEHK: "HK", - Exchange.HKFE: "HK_FUTURE" + Exchange.HKFE: "HK_FUTURE", } EXCHANGE_FUTU2VT = {v: k for k, v in EXCHANGE_VT2FUTU.items()} @@ -52,7 +52,7 @@ PRODUCT_VT2FUTU = { Product.INDEX: "IDX", Product.ETF: "ETF", Product.WARRANT: "WARRANT", - Product.BOND: "BOND" + Product.BOND: "BOND", } DIRECTION_VT2FUTU = {Direction.LONG: TrdSide.BUY, Direction.SHORT: TrdSide.SELL} @@ -79,10 +79,8 @@ class FutuGateway(BaseGateway): "password": "", "host": "127.0.0.1", "port": 11111, - "market": ["HK", - "US"], - "env": [TrdEnv.REAL, - TrdEnv.SIMULATE] + "market": ["HK", "US"], + "env": [TrdEnv.REAL, TrdEnv.SIMULATE], } def __init__(self, event_engine): @@ -126,7 +124,7 @@ class FutuGateway(BaseGateway): """ Query all data necessary. """ - sleep(2.0) # Wait 2 seconds till connection completed. + sleep(2.0) # Wait 2 seconds till connection completed. self.query_contract() self.query_trade() @@ -235,7 +233,7 @@ class FutuGateway(BaseGateway): def send_order(self, req): """""" side = DIRECTION_VT2FUTU[req.direction] - price_type = OrderType.NORMAL # Only limit order is supported. + price_type = OrderType.NORMAL # Only limit order is supported. # Set price adjustment mode to inside adjustment. if req.direction is Direction.LONG: @@ -251,7 +249,7 @@ class FutuGateway(BaseGateway): side, price_type, trd_env=self.env, - adjust_limit=adjust_limit + adjust_limit=adjust_limit, ) if code: @@ -268,11 +266,8 @@ class FutuGateway(BaseGateway): def cancel_order(self, req): """""" code, data = self.trade_ctx.modify_order( - ModifyOrderOp.CANCEL, - req.orderid, - 0, - 0, - trd_env=self.env) + ModifyOrderOp.CANCEL, req.orderid, 0, 0, trd_env=self.env + ) if code: self.write_log(f"撤单失败:{data}") @@ -295,7 +290,7 @@ class FutuGateway(BaseGateway): product=product, size=1, pricetick=0.001, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.on_contract(contract) self.contracts[contract.vt_symbol] = contract @@ -314,11 +309,8 @@ class FutuGateway(BaseGateway): account = AccountData( accountid=f"{self.gateway_name}_{self.market}", balance=float(row["total_assets"]), - frozen=( - float(row["total_assets"]) - - float(row["avl_withdrawal_cash"]) - ), - gateway_name=self.gateway_name + frozen=(float(row["total_assets"]) - float(row["avl_withdrawal_cash"])), + gateway_name=self.gateway_name, ) self.on_account(account) @@ -340,7 +332,7 @@ class FutuGateway(BaseGateway): frozen=(float(row["qty"]) - float(row["can_sell_qty"])), price=float(row["pl_val"]), pnl=float(row["cost_price"]), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.on_position(pos) @@ -386,7 +378,7 @@ class FutuGateway(BaseGateway): symbol=symbol, exchange=exchange, datetime=datetime.now(), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.ticks[code] = tick @@ -405,10 +397,7 @@ class FutuGateway(BaseGateway): date = row["data_date"].replace("-", "") time = row["data_time"] - tick.datetime = datetime.strptime( - f"{date} {time}", - "%Y%m%d %H:%M:%S" - ) + tick.datetime = datetime.strptime(f"{date} {time}", "%Y%m%d %H:%M:%S") tick.open_price = row["open_price"] tick.high_price = row["high_price"] tick.low_price = row["low_price"] @@ -462,7 +451,7 @@ class FutuGateway(BaseGateway): traded=float(row["dealt_qty"]), status=STATUS_FUTU2VT[row["order_status"]], time=row["create_time"].split(" ")[-1], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.on_order(order) @@ -487,7 +476,7 @@ class FutuGateway(BaseGateway): price=float(row["price"]), volume=float(row["qty"]), time=row["create_time"].split(" ")[-1], - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.on_trade(trade) diff --git a/vnpy/gateway/ib/__init__.py b/vnpy/gateway/ib/__init__.py index 9d915bff..d5e1f4d2 100644 --- a/vnpy/gateway/ib/__init__.py +++ b/vnpy/gateway/ib/__init__.py @@ -1 +1 @@ -from .ib_gateway import IbGateway \ No newline at end of file +from .ib_gateway import IbGateway diff --git a/vnpy/gateway/ib/ib_gateway.py b/vnpy/gateway/ib/ib_gateway.py index f3adc7fd..984acf2f 100644 --- a/vnpy/gateway/ib/ib_gateway.py +++ b/vnpy/gateway/ib/ib_gateway.py @@ -27,7 +27,7 @@ from vnpy.trader.object import ( AccountData, SubscribeRequest, OrderRequest, - CancelRequest + CancelRequest, ) from vnpy.trader.constant import ( Product, @@ -36,7 +36,7 @@ from vnpy.trader.constant import ( Exchange, Currency, Status, - OptionType + OptionType, ) PRICETYPE_VT2IB = {PriceType.LIMIT: "LMT", PriceType.MARKET: "MKT"} @@ -55,7 +55,7 @@ EXCHANGE_VT2IB = { Exchange.CME: "CME", Exchange.ICE: "ICE", Exchange.SEHK: "SEHK", - Exchange.HKFE: "HKFE" + Exchange.HKFE: "HKFE", } EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()} @@ -64,7 +64,7 @@ STATUS_IB2VT = { "Filled": Status.ALLTRADED, "Cancelled": Status.CANCELLED, "PendingSubmit": Status.SUBMITTING, - "PreSubmitted": Status.NOTTRADED + "PreSubmitted": Status.NOTTRADED, } PRODUCT_VT2IB = { @@ -72,7 +72,7 @@ PRODUCT_VT2IB = { Product.FOREX: "CASH", Product.SPOT: "CMDTY", Product.OPTION: "OPT", - Product.FUTURES: "FUT" + Product.FUTURES: "FUT", } PRODUCT_IB2VT = {v: k for k, v in PRODUCT_VT2IB.items()} @@ -91,7 +91,7 @@ TICKFIELD_IB2VT = { 7: "low_price", 8: "volume", 9: "pre_close", - 14: "open_price" + 14: "open_price", } ACCOUNTFIELD_IB2VT = { @@ -182,21 +182,21 @@ class IbApi(EWrapper): self.client = IbClient(self) self.thread = Thread(target=self.client.run) - def connectAck(self): # pylint: disable=invalid-name + def connectAck(self): # pylint: disable=invalid-name """ Callback when connection is established. """ self.status = True self.gateway.write_log("IB TWS连接成功") - def connectionClosed(self): # pylint: disable=invalid-name + def connectionClosed(self): # pylint: disable=invalid-name """ Callback when connection is closed. """ self.status = False self.gateway.write_log("IB TWS连接断开") - def nextValidId(self, orderId: int): # pylint: disable=invalid-name + def nextValidId(self, orderId: int): # pylint: disable=invalid-name """ Callback of next valid orderid. """ @@ -204,7 +204,7 @@ class IbApi(EWrapper): self.orderid = orderId - def currentTime(self, time: int): # pylint: disable=invalid-name + def currentTime(self, time: int): # pylint: disable=invalid-name """ Callback of current server time of IB. """ @@ -216,7 +216,9 @@ class IbApi(EWrapper): msg = f"服务器时间: {time_string}" self.gateway.write_log(msg) - def error(self, reqId: TickerId, errorCode: int, errorString: str): # pylint: disable=invalid-name + def error( + self, reqId: TickerId, errorCode: int, errorString: str + ): # pylint: disable=invalid-name """ Callback of error caused by specific request. """ @@ -226,11 +228,7 @@ class IbApi(EWrapper): self.gateway.write_log(msg) def tickPrice( # pylint: disable=invalid-name - self, - reqId: TickerId, - tickType: TickType, - price: float, - attrib: TickAttrib + self, reqId: TickerId, tickType: TickType, price: float, attrib: TickAttrib ): """ Callback of tick price update. @@ -257,7 +255,9 @@ class IbApi(EWrapper): tick.datetime = datetime.now() self.gateway.on_tick(copy(tick)) - def tickSize(self, reqId: TickerId, tickType: TickType, size: int): # pylint: disable=invalid-name + def tickSize( + self, reqId: TickerId, tickType: TickType, size: int + ): # pylint: disable=invalid-name """ Callback of tick volume update. """ @@ -272,7 +272,9 @@ class IbApi(EWrapper): self.gateway.on_tick(copy(tick)) - def tickString(self, reqId: TickerId, tickType: TickType, value: str): # pylint: disable=invalid-name + def tickString( + self, reqId: TickerId, tickType: TickType, value: str + ): # pylint: disable=invalid-name """ Callback of tick string update. """ @@ -287,36 +289,35 @@ class IbApi(EWrapper): self.gateway.on_tick(copy(tick)) def orderStatus( # pylint: disable=invalid-name - self, - orderId: OrderId, - status: str, - filled: float, - remaining: float, - avgFillPrice: float, - permId: int, - parentId: int, - lastFillPrice: float, - clientId: int, - whyHeld: str, - mktCapPrice: float + self, + orderId: OrderId, + status: str, + filled: float, + remaining: float, + avgFillPrice: float, + permId: int, + parentId: int, + lastFillPrice: float, + clientId: int, + whyHeld: str, + mktCapPrice: float, ): """ Callback of order status update. """ - super(IbApi, - self).orderStatus( - orderId, - status, - filled, - remaining, - avgFillPrice, - permId, - parentId, - lastFillPrice, - clientId, - whyHeld, - mktCapPrice - ) + super(IbApi, self).orderStatus( + orderId, + status, + filled, + remaining, + avgFillPrice, + permId, + parentId, + lastFillPrice, + clientId, + whyHeld, + mktCapPrice, + ) orderid = str(orderId) order = self.orders.get(orderid, None) @@ -326,11 +327,11 @@ class IbApi(EWrapper): self.gateway.on_order(copy(order)) def openOrder( # pylint: disable=invalid-name - self, - orderId: OrderId, - ib_contract: Contract, - ib_order: Order, - orderState: OrderState + self, + orderId: OrderId, + ib_contract: Contract, + ib_order: Order, + orderState: OrderState, ): """ Callback when opening new order. @@ -340,26 +341,19 @@ class IbApi(EWrapper): orderid = str(orderId) order = OrderData( symbol=ib_contract.conId, - exchange=EXCHANGE_IB2VT.get( - ib_contract.exchange, - ib_contract.exchange - ), + exchange=EXCHANGE_IB2VT.get(ib_contract.exchange, ib_contract.exchange), orderid=orderid, direction=DIRECTION_IB2VT[ib_order.action], price=ib_order.lmtPrice, volume=ib_order.totalQuantity, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.orders[orderid] = order self.gateway.on_order(copy(order)) def updateAccountValue( # pylint: disable=invalid-name - self, - key: str, - val: str, - currency: str, - accountName: str + self, key: str, val: str, currency: str, accountName: str ): """ Callback of account update. @@ -372,54 +366,49 @@ class IbApi(EWrapper): accountid = f"{accountName}.{currency}" account = self.accounts.get(accountid, None) if not account: - account = AccountData( - accountid=accountid, - gateway_name=self.gateway_name - ) + account = AccountData(accountid=accountid, gateway_name=self.gateway_name) self.accounts[accountid] = account name = ACCOUNTFIELD_IB2VT[key] setattr(account, name, float(val)) def updatePortfolio( # pylint: disable=invalid-name - self, - contract: Contract, - position: float, - marketPrice: float, - marketValue: float, - averageCost: float, - unrealizedPNL: float, - realizedPNL: float, - accountName: str + self, + contract: Contract, + position: float, + marketPrice: float, + marketValue: float, + averageCost: float, + unrealizedPNL: float, + realizedPNL: float, + accountName: str, ): """ Callback of position update. """ - super(IbApi, - self).updatePortfolio( - contract, - position, - marketPrice, - marketValue, - averageCost, - unrealizedPNL, - realizedPNL, - accountName - ) + super(IbApi, self).updatePortfolio( + contract, + position, + marketPrice, + marketValue, + averageCost, + unrealizedPNL, + realizedPNL, + accountName, + ) pos = PositionData( symbol=contract.conId, - exchange=EXCHANGE_IB2VT.get(contract.exchange, - contract.exchange), + exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange), direction=DIRECTION_NET, volume=position, price=averageCost, pnl=unrealizedPNL, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.gateway.on_position(pos) - def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name + def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name """ Callback of account update time. """ @@ -427,7 +416,9 @@ class IbApi(EWrapper): for account in self.accounts.values(): self.gateway.on_account(copy(account)) - def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name + def contractDetails( + self, reqId: int, contractDetails: ContractDetails + ): # pylint: disable=invalid-name """ Callback of contract data update. """ @@ -443,20 +434,21 @@ class IbApi(EWrapper): contract = ContractData( symbol=ib_symbol, - exchange=EXCHANGE_IB2VT.get(ib_exchange, - ib_exchange), + exchange=EXCHANGE_IB2VT.get(ib_exchange, ib_exchange), name=contractDetails.longName, product=PRODUCT_IB2VT[ib_product], size=ib_size, pricetick=contractDetails.minTick, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.gateway.on_contract(contract) self.contracts[contract.vt_symbol] = contract - def execDetails(self, reqId: int, contract: Contract, execution: Execution): # pylint: disable=invalid-name + def execDetails( + self, reqId: int, contract: Contract, execution: Execution + ): # pylint: disable=invalid-name """ Callback of trade data update. """ @@ -465,21 +457,19 @@ class IbApi(EWrapper): today_date = datetime.now().strftime("%Y%m%d") trade = TradeData( symbol=contract.conId, - exchange=EXCHANGE_IB2VT.get(contract.exchange, - contract.exchange), + exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange), orderid=str(execution.orderId), tradeid=str(execution.execId), direction=DIRECTION_IB2VT[execution.side], price=execution.price, volume=execution.shares, - time=datetime.strptime(execution.time, - "%Y%m%d %H:%M:%S"), - gateway_name=self.gateway_name + time=datetime.strptime(execution.time, "%Y%m%d %H:%M:%S"), + gateway_name=self.gateway_name, ) self.gateway.on_trade(trade) - def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name + def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name """ Callback of all sub accountid. """ @@ -497,11 +487,7 @@ class IbApi(EWrapper): self.clientid = setting["clientid"] - self.client.connect( - setting["host"], - setting["port"], - setting["clientid"] - ) + self.client.connect(setting["host"], setting["port"], setting["clientid"]) self.thread.start() @@ -544,7 +530,7 @@ class IbApi(EWrapper): symbol=req.symbol, exchange=req.exchange, datetime=datetime.now(), - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) self.ticks[self.reqid] = tick self.tick_exchange[self.reqid] = req.exchange diff --git a/vnpy/trader/app.py b/vnpy/trader/app.py index e78dc771..ccd4900a 100644 --- a/vnpy/trader/app.py +++ b/vnpy/trader/app.py @@ -7,10 +7,11 @@ class BaseApp(ABC): """ Absstract class for app. """ - app_name = "" # Unique name used for creating engine and widget - app_module = "" # App module string used in import_module - app_path = "" # Absolute path of app folder - display_name = "" # Name for display on the menu. - engine_class = None # App engine class - widget_name = "" # Class name of app widget - icon_name = "" # Icon file name of app widget + + app_name = "" # Unique name used for creating engine and widget + app_module = "" # App module string used in import_module + app_path = "" # Absolute path of app folder + display_name = "" # Name for display on the menu. + engine_class = None # App engine class + widget_name = "" # Class name of app widget + icon_name = "" # Icon file name of app widget diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 99a39a14..c68105aa 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -9,6 +9,7 @@ class Direction(Enum): """ Direction of order/trade/position. """ + LONG = "多" SHORT = "空" NET = "净" @@ -18,6 +19,7 @@ class Offset(Enum): """ Offset of order/trade. """ + NONE = "" OPEN = "开" CLOSE = "平" @@ -29,6 +31,7 @@ class Status(Enum): """ Order status. """ + SUBMITTING = "提交中" NOTTRADED = "未成交" PARTTRADED = "部分成交" @@ -41,6 +44,7 @@ class Product(Enum): """ Product class. """ + EQUITY = "股票" FUTURES = "期货" OPTION = "期权" @@ -56,6 +60,7 @@ class PriceType(Enum): """ Order price type. """ + LIMIT = "限价" MARKET = "市价" FAK = "FAK" @@ -66,6 +71,7 @@ class OptionType(Enum): """ Option type. """ + CALL = "看涨期权" PUT = "看跌期权" @@ -74,6 +80,7 @@ class Exchange(Enum): """ Exchange. """ + # Chinese CFFEX = "CFFEX" SHFE = "SHFE" @@ -102,6 +109,7 @@ class Currency(Enum): """ Currency. """ + USD = "USD" HKD = "HKD" CNY = "CNY" @@ -111,4 +119,4 @@ class Interval(Enum): MINUTE = "1m" HOUR = "1h" DAILY = "d" - WEEKLY = "w" \ No newline at end of file + WEEKLY = "w" diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 2ae2e290..49b1826a 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -6,7 +6,7 @@ from peewee import ( CharField, DateTimeField, FloatField, - IntegerField + IntegerField, ) from .utility import get_temp_path @@ -23,6 +23,7 @@ class DbBarData(Model): Index is defined unique with vt_symbol, interval and datetime. """ + symbol = CharField() exchange = CharField() datetime = DateTimeField() @@ -39,7 +40,7 @@ class DbBarData(Model): class Meta: database = DB - indexes = ((('vt_symbol', 'interval', 'datetime'), True),) + indexes = ((("vt_symbol", "interval", "datetime"), True),) @staticmethod def from_bar(bar: BarData): @@ -76,7 +77,7 @@ class DbBarData(Model): high_price=high_price, low_price=low_price, close_price=close_price, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) return bar @@ -87,6 +88,7 @@ class DbTickData(Model): Index is defined unique with vt_symbol, interval and datetime. """ + symbol = CharField() exchange = CharField() datetime = DateTimeField() @@ -132,7 +134,7 @@ class DbTickData(Model): class Meta: database = DB - indexes = ((('vt_symbol', 'datetime'), True),) + indexes = ((("vt_symbol", "datetime"), True),) @staticmethod def from_tick(tick: TickData): @@ -208,7 +210,7 @@ class DbTickData(Model): ask_price_1=self.ask_price_1, bid_volume_1=self.bid_volume_1, ask_volume_1=self.ask_volume_1, - gateway_name=self.gateway_name + gateway_name=self.gateway_name, ) if self.bid_price_2: diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index dc9c6cdf..b742abd5 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -19,7 +19,7 @@ from .event import ( EVENT_TRADE, EVENT_POSITION, EVENT_ACCOUNT, - EVENT_CONTRACT + EVENT_CONTRACT, ) from .object import LogData, SubscribeRequest, OrderRequest, CancelRequest from .utility import Singleton, get_temp_path @@ -180,10 +180,7 @@ class BaseEngine(ABC): """ def __init__( - self, - main_engine: MainEngine, - event_engine: EventEngine, - engine_name: str + self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str ): """""" self.main_engine = main_engine @@ -211,9 +208,7 @@ class LogEngine(BaseEngine): self.level = SETTINGS["log.level"] self.logger = logging.getLogger("VN Trader") - self.formatter = logging.Formatter( - "%(asctime)s %(levelname)s: %(message)s" - ) + self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s") self.add_null_handler() @@ -249,7 +244,7 @@ class LogEngine(BaseEngine): filename = f"vt_{today_date}.log" file_path = get_temp_path(filename) - file_handler = logging.FileHandler(file_path, mode='w', encoding='utf8') + file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8") file_handler.setLevel(self.level) file_handler.setFormatter(self.formatter) self.logger.addHandler(file_handler) @@ -421,7 +416,7 @@ class OmsEngine(BaseEngine): """ return list(self.contracts.values()) - def get_all_active_orders(self, vt_symbol: str = ''): + def get_all_active_orders(self, vt_symbol: str = ""): """ Get all active orders by vt_symbol. @@ -431,7 +426,8 @@ class OmsEngine(BaseEngine): return list(self.active_orders.values()) else: active_orders = [ - order for order in self.active_orders.values() + order + for order in self.active_orders.values() if order.vt_symbol == vt_symbol ] return active_orders @@ -476,12 +472,10 @@ class EmailEngine(BaseEngine): try: msg = self.queue.get(block=True, timeout=1) - with smtplib.SMTP_SSL(SETTINGS["email.server"], - SETTINGS["email.port"]) as smtp: - smtp.login( - SETTINGS["email.username"], - SETTINGS["email.password"] - ) + with smtplib.SMTP_SSL( + SETTINGS["email.server"], SETTINGS["email.port"] + ) as smtp: + smtp.login(SETTINGS["email.username"], SETTINGS["email.password"]) smtp.send_message(msg) except Empty: pass diff --git a/vnpy/trader/event.py b/vnpy/trader/event.py index c679a971..6c92f948 100644 --- a/vnpy/trader/event.py +++ b/vnpy/trader/event.py @@ -4,10 +4,10 @@ Event type string used in VN Trader. from vnpy.event import EVENT_TIMER -EVENT_TICK = 'eTick.' -EVENT_TRADE = 'eTrade.' -EVENT_ORDER = 'eOrder.' -EVENT_POSITION = 'ePosition.' -EVENT_ACCOUNT = 'eAccount.' -EVENT_CONTRACT = 'eContract.' -EVENT_LOG = 'eLog' +EVENT_TICK = "eTick." +EVENT_TRADE = "eTrade." +EVENT_ORDER = "eOrder." +EVENT_POSITION = "ePosition." +EVENT_ACCOUNT = "eAccount." +EVENT_CONTRACT = "eContract." +EVENT_LOG = "eLog" diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 5104ba31..a5916281 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -14,7 +14,7 @@ from .event import ( EVENT_ACCOUNT, EVENT_POSITION, EVENT_LOG, - EVENT_CONTRACT + EVENT_CONTRACT, ) from .object import ( TickData, @@ -26,7 +26,7 @@ from .object import ( ContractData, SubscribeRequest, OrderRequest, - CancelRequest + CancelRequest, ) @@ -163,4 +163,4 @@ class BaseGateway(ABC): """ Return default setting dict. """ - return self.default_setting \ No newline at end of file + return self.default_setting diff --git a/vnpy/trader/object.py b/vnpy/trader/object.py index 70be6cb5..c0243239 100644 --- a/vnpy/trader/object.py +++ b/vnpy/trader/object.py @@ -17,6 +17,7 @@ class BaseData: Any data object needs a gateway_name as source or destination and should inherit base data. """ + gateway_name: str @@ -28,6 +29,7 @@ class TickData(BaseData): * orderbook snapshot * intraday market statistics. """ + symbol: str exchange: Exchange datetime: datetime @@ -78,6 +80,7 @@ class BarData(BaseData): """ Candlestick bar data of a certain trading period. """ + symbol: str exchange: Exchange datetime: datetime @@ -100,6 +103,7 @@ class OrderData(BaseData): Order data contains information for tracking lastest status of a specific order. """ + symbol: str exchange: Exchange orderid: str @@ -131,9 +135,7 @@ class OrderData(BaseData): Create cancel request object from order. """ req = CancelRequest( - orderid=self.orderid, - symbol=self.symbol, - exchange=self.exchange + orderid=self.orderid, symbol=self.symbol, exchange=self.exchange ) return req @@ -144,6 +146,7 @@ class TradeData(BaseData): Trade data contains information of a fill of an order. One order can have several trade fills. """ + symbol: str exchange: Exchange orderid: str @@ -167,6 +170,7 @@ class PositionData(BaseData): """ Positon data is used for tracking each individual position holding. """ + symbol: str exchange: Exchange direction: Direction @@ -188,6 +192,7 @@ class AccountData(BaseData): Account data contains information about balance, frozen and available. """ + accountid: str balance: float = 0 @@ -204,6 +209,7 @@ class LogData(BaseData): """ Log data is used for recording log messages on GUI or in log files. """ + msg: str level: int = INFO @@ -217,6 +223,7 @@ class ContractData(BaseData): """ Contract data contains basic information about each contract traded. """ + symbol: str exchange: Exchange name: str @@ -225,8 +232,8 @@ class ContractData(BaseData): pricetick: float option_strike: float = 0 - option_underlying: str = '' # vt_symbol of underlying contract - option_type: str = '' + option_underlying: str = "" # vt_symbol of underlying contract + option_type: str = "" option_expiry: datetime = None def __post_init__(self): @@ -239,6 +246,7 @@ class SubscribeRequest: """ Request sending to specific gateway for subscribing tick data update. """ + symbol: str exchange: Exchange @@ -252,6 +260,7 @@ class OrderRequest: """ Request sending to specific gateway for creating a new order. """ + symbol: str exchange: Exchange direction: Direction @@ -276,7 +285,7 @@ class OrderRequest: offset=self.offset, price=self.price, volume=self.volume, - gateway_name=gateway_name + gateway_name=gateway_name, ) return order @@ -286,6 +295,7 @@ class CancelRequest: """ Request sending to specific gateway for canceling an existing order. """ + orderid: str symbol: str exchange: Exchange diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index f91d121d..f5aa754d 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -16,5 +16,5 @@ SETTINGS = { "email.username": "", "email.password": "", "email.sender": "", - "email.receiver": "" -} \ No newline at end of file + "email.receiver": "", +} diff --git a/vnpy/trader/ui/__init__.py b/vnpy/trader/ui/__init__.py index 665d6a59..d1424ec2 100644 --- a/vnpy/trader/ui/__init__.py +++ b/vnpy/trader/ui/__init__.py @@ -14,13 +14,8 @@ from ..utility import get_icon_path def excepthook(exctype, value, tb): """异常捕捉钩子""" - msg = ''.join(traceback.format_exception(exctype, value, tb)) - QtWidgets.QMessageBox.critical( - None, - u'Exception', - msg, - QtWidgets.QMessageBox.Ok - ) + msg = "".join(traceback.format_exception(exctype, value, tb)) + QtWidgets.QMessageBox.critical(None, u"Exception", msg, QtWidgets.QMessageBox.Ok) def create_qapp(): @@ -38,9 +33,7 @@ def create_qapp(): icon = QtGui.QIcon(get_icon_path(__file__, "vnpy.ico")) qapp.setWindowIcon(icon) - if 'Windows' in platform.uname(): - ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID( - 'VN Trader' - ) + if "Windows" in platform.uname(): + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("VN Trader") - return qapp \ No newline at end of file + return qapp diff --git a/vnpy/trader/ui/mainwindow.py b/vnpy/trader/ui/mainwindow.py index f436efef..65f6f699 100644 --- a/vnpy/trader/ui/mainwindow.py +++ b/vnpy/trader/ui/mainwindow.py @@ -22,7 +22,7 @@ from .widget import ( TradingWidget, ActiveOrderMonitor, ContractManager, - AboutDialog + AboutDialog, ) @@ -54,14 +54,30 @@ class MainWindow(QtWidgets.QMainWindow): def init_dock(self): """""" - trading_widget, trading_dock = self.create_dock(TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea) - tick_widget, tick_dock = self.create_dock(TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea) - order_widget, order_dock = self.create_dock(OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea) - active_widget, active_dock = self.create_dock(ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea) - trade_widget, trade_dock = self.create_dock(TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea) - log_widget, log_dock = self.create_dock(LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea) - account_widget, account_dock = self.create_dock(AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea) - position_widget, position_dock = self.create_dock(PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea) + trading_widget, trading_dock = self.create_dock( + TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea + ) + tick_widget, tick_dock = self.create_dock( + TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea + ) + order_widget, order_dock = self.create_dock( + OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea + ) + active_widget, active_dock = self.create_dock( + ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea + ) + trade_widget, trade_dock = self.create_dock( + TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea + ) + log_widget, log_dock = self.create_dock( + LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea + ) + account_widget, account_dock = self.create_dock( + AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea + ) + position_widget, position_dock = self.create_dock( + PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea + ) self.tabifyDockWidget(active_dock, order_dock) @@ -96,52 +112,31 @@ class MainWindow(QtWidgets.QMainWindow): func = partial(self.open_widget, widget_class, app.app_name) icon_path = str(app.app_path.joinpath("ui", app.icon_name)) - self.add_menu_action( - app_menu, - f"打开{app.display_name}", - icon_path, - func - ) + self.add_menu_action(app_menu, f"打开{app.display_name}", icon_path, func) # Help menu self.add_menu_action( help_menu, "查询合约", "contract.ico", - partial(self.open_widget, - ContractManager, - "contract") + partial(self.open_widget, ContractManager, "contract"), ) self.add_menu_action( - help_menu, - "还原窗口", - "restore.ico", - self.restore_window_setting + help_menu, "还原窗口", "restore.ico", self.restore_window_setting ) - self.add_menu_action( - help_menu, - "测试邮件", - "email.ico", - self.send_test_email - ) + self.add_menu_action(help_menu, "测试邮件", "email.ico", self.send_test_email) self.add_menu_action( help_menu, "关于", "about.ico", - partial(self.open_widget, - AboutDialog, - "about") + partial(self.open_widget, AboutDialog, "about"), ) def add_menu_action( - self, - menu: QtWidgets.QMenu, - action_name: str, - icon_name: str, - func: Callable + self, menu: QtWidgets.QMenu, action_name: str, icon_name: str, func: Callable ): """""" icon = QtGui.QIcon(get_icon_path(__file__, icon_name)) @@ -152,12 +147,7 @@ class MainWindow(QtWidgets.QMainWindow): menu.addAction(action) - def create_dock( - self, - widget_class: QtWidgets.QWidget, - name: str, - area: int - ): + def create_dock(self, widget_class: QtWidgets.QWidget, name: str, area: int): """ Initialize a dock widget. """ @@ -189,7 +179,7 @@ class MainWindow(QtWidgets.QMainWindow): "退出", "确认退出?", QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No, - QtWidgets.QMessageBox.No + QtWidgets.QMessageBox.No, ) if reply == QtWidgets.QMessageBox.Yes: diff --git a/vnpy/trader/ui/widget.py b/vnpy/trader/ui/widget.py index d472aa51..d50c400a 100644 --- a/vnpy/trader/ui/widget.py +++ b/vnpy/trader/ui/widget.py @@ -18,7 +18,7 @@ from ..event import ( EVENT_ACCOUNT, EVENT_POSITION, EVENT_CONTRACT, - EVENT_LOG + EVENT_LOG, ) from ..object import SubscribeRequest, OrderRequest, CancelRequest from ..utility import load_setting, save_setting @@ -299,22 +299,19 @@ class BaseMonitor(QtWidgets.QTableWidget): """ Resize all columns according to contents. """ - self.horizontalHeader().resizeSections( - QtWidgets.QHeaderView.ResizeToContents - ) + self.horizontalHeader().resizeSections(QtWidgets.QHeaderView.ResizeToContents) def save_csv(self): """ Save table data into a csv file """ - path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "", - "CSV(*.csv)") + path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "", "CSV(*.csv)") if not path: return with open(path, "w") as f: - writer = csv.writer(f, lineterminator='\n') + writer = csv.writer(f, lineterminator="\n") writer.writerow(self.headers.keys()) @@ -339,81 +336,26 @@ class TickMonitor(BaseMonitor): """ Monitor for tick data. """ + event_type = EVENT_TICK data_key = "vt_symbol" sorting = True headers = { - "symbol": { - "display": "代码", - "cell": BaseCell, - "update": False - }, - "exchange": { - "display": "交易所", - "cell": EnumCell, - "update": False - }, - "name": { - "display": "名称", - "cell": BaseCell, - "update": True - }, - "last_price": { - "display": "最新价", - "cell": BaseCell, - "update": True - }, - "volume": { - "display": "成交量", - "cell": BaseCell, - "update": True - }, - "open_price": { - "display": "开盘价", - "cell": BaseCell, - "update": True - }, - "high_price": { - "display": "最高价", - "cell": BaseCell, - "update": True - }, - "low_price": { - "display": "最低价", - "cell": BaseCell, - "update": True - }, - "bid_price_1": { - "display": "买1价", - "cell": BidCell, - "update": True - }, - "bid_volume_1": { - "display": "买1量", - "cell": BidCell, - "update": True - }, - "ask_price_1": { - "display": "卖1价", - "cell": AskCell, - "update": True - }, - "ask_volume_1": { - "display": "卖1量", - "cell": AskCell, - "update": True - }, - "datetime": { - "display": "时间", - "cell": TimeCell, - "update": True - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "symbol": {"display": "代码", "cell": BaseCell, "update": False}, + "exchange": {"display": "交易所", "cell": EnumCell, "update": False}, + "name": {"display": "名称", "cell": BaseCell, "update": True}, + "last_price": {"display": "最新价", "cell": BaseCell, "update": True}, + "volume": {"display": "成交量", "cell": BaseCell, "update": True}, + "open_price": {"display": "开盘价", "cell": BaseCell, "update": True}, + "high_price": {"display": "最高价", "cell": BaseCell, "update": True}, + "low_price": {"display": "最低价", "cell": BaseCell, "update": True}, + "bid_price_1": {"display": "买1价", "cell": BidCell, "update": True}, + "bid_volume_1": {"display": "买1量", "cell": BidCell, "update": True}, + "ask_price_1": {"display": "卖1价", "cell": AskCell, "update": True}, + "ask_volume_1": {"display": "卖1量", "cell": AskCell, "update": True}, + "datetime": {"display": "时间", "cell": TimeCell, "update": True}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } @@ -421,26 +363,15 @@ class LogMonitor(BaseMonitor): """ Monitor for log data. """ + event_type = EVENT_LOG data_key = "" sorting = False headers = { - "time": { - "display": "时间", - "cell": TimeCell, - "update": False - }, - "msg": { - "display": "信息", - "cell": MsgCell, - "update": False - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "time": {"display": "时间", "cell": TimeCell, "update": False}, + "msg": {"display": "信息", "cell": MsgCell, "update": False}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } @@ -448,61 +379,22 @@ class TradeMonitor(BaseMonitor): """ Monitor for trade data. """ + event_type = EVENT_TRADE data_key = "" sorting = True headers = { - "tradeid": { - "display": "成交号 ", - "cell": BaseCell, - "update": False - }, - "orderid": { - "display": "委托号", - "cell": BaseCell, - "update": False - }, - "symbol": { - "display": "代码", - "cell": BaseCell, - "update": False - }, - "exchange": { - "display": "交易所", - "cell": EnumCell, - "update": False - }, - "direction": { - "display": "方向", - "cell": DirectionCell, - "update": False - }, - "offset": { - "display": "开平", - "cell": EnumCell, - "update": False - }, - "price": { - "display": "价格", - "cell": BaseCell, - "update": False - }, - "volume": { - "display": "数量", - "cell": BaseCell, - "update": False - }, - "time": { - "display": "时间", - "cell": BaseCell, - "update": False - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "tradeid": {"display": "成交号 ", "cell": BaseCell, "update": False}, + "orderid": {"display": "委托号", "cell": BaseCell, "update": False}, + "symbol": {"display": "代码", "cell": BaseCell, "update": False}, + "exchange": {"display": "交易所", "cell": EnumCell, "update": False}, + "direction": {"display": "方向", "cell": DirectionCell, "update": False}, + "offset": {"display": "开平", "cell": EnumCell, "update": False}, + "price": {"display": "价格", "cell": BaseCell, "update": False}, + "volume": {"display": "数量", "cell": BaseCell, "update": False}, + "time": {"display": "时间", "cell": BaseCell, "update": False}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } @@ -510,66 +402,23 @@ class OrderMonitor(BaseMonitor): """ Monitor for order data. """ + event_type = EVENT_ORDER data_key = "vt_orderid" sorting = True headers = { - "orderid": { - "display": "委托号", - "cell": BaseCell, - "update": False - }, - "symbol": { - "display": "代码", - "cell": BaseCell, - "update": False - }, - "exchange": { - "display": "交易所", - "cell": EnumCell, - "update": False - }, - "direction": { - "display": "方向", - "cell": DirectionCell, - "update": False - }, - "offset": { - "display": "开平", - "cell": EnumCell, - "update": False - }, - "price": { - "display": "价格", - "cell": BaseCell, - "update": False - }, - "volume": { - "display": "总数量", - "cell": BaseCell, - "update": True - }, - "traded": { - "display": "已成交", - "cell": BaseCell, - "update": True - }, - "status": { - "display": "状态", - "cell": EnumCell, - "update": True - }, - "time": { - "display": "时间", - "cell": BaseCell, - "update": True - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "orderid": {"display": "委托号", "cell": BaseCell, "update": False}, + "symbol": {"display": "代码", "cell": BaseCell, "update": False}, + "exchange": {"display": "交易所", "cell": EnumCell, "update": False}, + "direction": {"display": "方向", "cell": DirectionCell, "update": False}, + "offset": {"display": "开平", "cell": EnumCell, "update": False}, + "price": {"display": "价格", "cell": BaseCell, "update": False}, + "volume": {"display": "总数量", "cell": BaseCell, "update": True}, + "traded": {"display": "已成交", "cell": BaseCell, "update": True}, + "status": {"display": "状态", "cell": EnumCell, "update": True}, + "time": {"display": "时间", "cell": BaseCell, "update": True}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } def init_ui(self): @@ -594,51 +443,20 @@ class PositionMonitor(BaseMonitor): """ Monitor for position data. """ + event_type = EVENT_POSITION data_key = "vt_positionid" sorting = True headers = { - "symbol": { - "display": "代码", - "cell": BaseCell, - "update": False - }, - "exchange": { - "display": "交易所", - "cell": EnumCell, - "update": False - }, - "direction": { - "display": "方向", - "cell": DirectionCell, - "update": False - }, - "volume": { - "display": "数量", - "cell": BaseCell, - "update": True - }, - "frozen": { - "display": "冻结", - "cell": BaseCell, - "update": True - }, - "price": { - "display": "均价", - "cell": BaseCell, - "update": False - }, - "pnl": { - "display": "盈亏", - "cell": PnlCell, - "update": True - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "symbol": {"display": "代码", "cell": BaseCell, "update": False}, + "exchange": {"display": "交易所", "cell": EnumCell, "update": False}, + "direction": {"display": "方向", "cell": DirectionCell, "update": False}, + "volume": {"display": "数量", "cell": BaseCell, "update": True}, + "frozen": {"display": "冻结", "cell": BaseCell, "update": True}, + "price": {"display": "均价", "cell": BaseCell, "update": False}, + "pnl": {"display": "盈亏", "cell": PnlCell, "update": True}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } @@ -646,36 +464,17 @@ class AccountMonitor(BaseMonitor): """ Monitor for account data. """ + event_type = EVENT_ACCOUNT data_key = "vt_accountid" sorting = True headers = { - "accountid": { - "display": "账号", - "cell": BaseCell, - "update": False - }, - "balance": { - "display": "余额", - "cell": BaseCell, - "update": True - }, - "frozen": { - "display": "冻结", - "cell": BaseCell, - "update": True - }, - "available": { - "display": "可用", - "cell": BaseCell, - "update": True - }, - "gateway_name": { - "display": "接口", - "cell": BaseCell, - "update": False - } + "accountid": {"display": "账号", "cell": BaseCell, "update": False}, + "balance": {"display": "余额", "cell": BaseCell, "update": True}, + "frozen": {"display": "冻结", "cell": BaseCell, "update": True}, + "available": {"display": "可用", "cell": BaseCell, "update": True}, + "gateway_name": {"display": "接口", "cell": BaseCell, "update": False}, } @@ -701,9 +500,7 @@ class ConnectDialog(QtWidgets.QDialog): self.setWindowTitle(f"连接{self.gateway_name}") # Default setting provides field name, field data type and field default value. - default_setting = self.main_engine.get_default_setting( - self.gateway_name - ) + default_setting = self.main_engine.get_default_setting(self.gateway_name) # Saved setting provides field data used last time. loaded_setting = load_setting(self.filename) @@ -732,7 +529,7 @@ class ConnectDialog(QtWidgets.QDialog): form.addRow(f"{field_name} <{field_type.__name__}>", widget) self.widgets[field_name] = (widget, field_type) - button = QtWidgets.QPushButton(u"连接") + button = QtWidgets.QPushButton("连接") button.clicked.connect(self.connect) form.addRow(button) @@ -792,18 +589,13 @@ class TradingWidget(QtWidgets.QWidget): self.name_line.setReadOnly(True) self.direction_combo = QtWidgets.QComboBox() - self.direction_combo.addItems( - [Direction.LONG.value, - Direction.SHORT.value] - ) + self.direction_combo.addItems([Direction.LONG.value, Direction.SHORT.value]) self.offset_combo = QtWidgets.QComboBox() self.offset_combo.addItems([offset.value for offset in Offset]) self.price_type_combo = QtWidgets.QComboBox() - self.price_type_combo.addItems( - [price_type.value for price_type in PriceType] - ) + self.price_type_combo.addItems([price_type.value for price_type in PriceType]) double_validator = QtGui.QDoubleValidator() double_validator.setBottom(0) @@ -846,26 +638,11 @@ class TradingWidget(QtWidgets.QWidget): self.bp4_label = self.create_label(bid_color) self.bp5_label = self.create_label(bid_color) - self.bv1_label = self.create_label( - bid_color, - alignment=QtCore.Qt.AlignRight - ) - self.bv2_label = self.create_label( - bid_color, - alignment=QtCore.Qt.AlignRight - ) - self.bv3_label = self.create_label( - bid_color, - alignment=QtCore.Qt.AlignRight - ) - self.bv4_label = self.create_label( - bid_color, - alignment=QtCore.Qt.AlignRight - ) - self.bv5_label = self.create_label( - bid_color, - alignment=QtCore.Qt.AlignRight - ) + self.bv1_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight) + self.bv2_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight) + self.bv3_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight) + self.bv4_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight) + self.bv5_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight) self.ap1_label = self.create_label(ask_color) self.ap2_label = self.create_label(ask_color) @@ -873,26 +650,11 @@ class TradingWidget(QtWidgets.QWidget): self.ap4_label = self.create_label(ask_color) self.ap5_label = self.create_label(ask_color) - self.av1_label = self.create_label( - ask_color, - alignment=QtCore.Qt.AlignRight - ) - self.av2_label = self.create_label( - ask_color, - alignment=QtCore.Qt.AlignRight - ) - self.av3_label = self.create_label( - ask_color, - alignment=QtCore.Qt.AlignRight - ) - self.av4_label = self.create_label( - ask_color, - alignment=QtCore.Qt.AlignRight - ) - self.av5_label = self.create_label( - ask_color, - alignment=QtCore.Qt.AlignRight - ) + self.av1_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight) + self.av2_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight) + self.av3_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight) + self.av4_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight) + self.av5_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight) self.lp_label = self.create_label() self.return_label = self.create_label(alignment=QtCore.Qt.AlignRight) @@ -916,11 +678,7 @@ class TradingWidget(QtWidgets.QWidget): vbox.addLayout(form2) self.setLayout(vbox) - def create_label( - self, - color: str = "", - alignment: int = QtCore.Qt.AlignLeft - ): + def create_label(self, color: str = "", alignment: int = QtCore.Qt.AlignLeft): """ Create label with certain font color. """ @@ -992,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget): contract = self.main_engine.get_contract(vt_symbol) if not contract: self.name_line.setText("") - gateway_name = (self.gateway_combo.currentText()) + gateway_name = self.gateway_combo.currentText() else: self.name_line.setText(contract.name) gateway_name = contract.gateway_name @@ -1067,7 +825,7 @@ class TradingWidget(QtWidgets.QWidget): price_type=PriceType(str(self.price_type_combo.currentText())), volume=volume, price=price, - offset=Offset(str(self.offset_combo.currentText())) + offset=Offset(str(self.offset_combo.currentText())), ) gateway_name = str(self.gateway_combo.currentText()) @@ -1118,7 +876,7 @@ class ContractManager(QtWidgets.QWidget): "product": "合约分类", "size": "合约乘数", "pricetick": "价格跳动", - "gateway_name": "交易接口" + "gateway_name": "交易接口", } def __init__(self, main_engine, event_engine): @@ -1171,8 +929,7 @@ class ContractManager(QtWidgets.QWidget): all_contracts = self.main_engine.get_all_contracts() if flt: contracts = [ - contract for contract in all_contracts - if flt in contract.vt_symbol + contract for contract in all_contracts if flt in contract.vt_symbol ] else: contracts = all_contracts diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 26b32e27..a21e376b 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -14,14 +14,13 @@ class Singleton(type): __metaclass__ = Singleton """ + _instances = {} def __call__(cls, *args, **kwargs): """""" if cls not in cls._instances: - cls._instances[cls] = super(Singleton, - cls).__call__(*args, - **kwargs) + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[cls] @@ -39,7 +38,7 @@ def get_temp_path(filename: str): Get path for temp file with filename. """ trader_path = get_trader_path() - temp_path = trader_path.joinpath('.vntrader') + temp_path = trader_path.joinpath(".vntrader") if not temp_path.exists(): temp_path.mkdir()