[Mod]change formatting tools to black

This commit is contained in:
vn.py 2019-01-26 17:24:38 +08:00
parent f9fd309098
commit 5d4e975ff4
32 changed files with 634 additions and 1077 deletions

View File

@ -2,4 +2,8 @@ PyQt5
qdarkstyle qdarkstyle
futu-api futu-api
websocket-client websocket-client
peewee peewee
numpy
pandas
matplotlib
seaborn

View File

@ -1,4 +1,4 @@
import unittest import unittest
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -5,20 +5,21 @@ from yapf.yapflib.yapf_api import FormatFile
logger = logging.Logger(__file__) logger = logging.Logger(__file__)
if __name__ == '__main__': if __name__ == "__main__":
has_changed = False has_changed = False
for root, dir, filenames in os.walk("vnpy"): for root, dir, filenames in os.walk("vnpy"):
for filename in filenames: for filename in filenames:
basename, ext = os.path.splitext(filename) basename, ext = os.path.splitext(filename)
if ext == '.py': if ext == ".py":
path = os.path.join(root, filename) path = os.path.join(root, filename)
reformatted_code, encoding, changed = FormatFile(filename=path, reformatted_code, encoding, changed = FormatFile(
style_config='.style.yapf', filename=path,
print_diff=True, style_config=".style.yapf",
verify=False, print_diff=True,
in_place=False, verify=False,
logger=None in_place=False,
) logger=None,
)
if changed: if changed:
has_changed = True has_changed = True
logger.warning("File {} not formatted!".format(path)) logger.warning("File {} not formatted!".format(path))

View File

@ -12,10 +12,10 @@ from typing import Any, Callable, Optional
class RequestStatus(Enum): class RequestStatus(Enum):
ready = 0 # Request created ready = 0 # Request created
success = 1 # Request successful (status code 2xx) success = 1 # Request successful (status code 2xx)
failed = 2 # Request failed (status code not 2xx) failed = 2 # Request failed (status code not 2xx)
error = 3 # Exception raised error = 3 # Exception raised
class Request(object): class Request(object):
@ -24,16 +24,16 @@ class Request(object):
""" """
def __init__( def __init__(
self, self,
method: str, method: str,
path: str, path: str,
params: dict, params: dict,
data: dict, data: dict,
headers: dict, headers: dict,
callback: Callable, callback: Callable,
on_failed: Callable = None, on_failed: Callable = None,
on_error: Callable = None, on_error: Callable = None,
extra: Any = None extra: Any = None,
): ):
"""""" """"""
self.method = method self.method = method
@ -52,7 +52,7 @@ class Request(object):
def __str__(self): def __str__(self):
if self.response is None: if self.response is None:
status_code = 'terminated' status_code = "terminated"
else: else:
status_code = self.response.status_code status_code = self.response.status_code
@ -70,7 +70,7 @@ class Request(object):
self.headers, self.headers,
self.params, self.params,
self.data, self.data,
'' if self.response is None else self.response.text "" if self.response is None else self.response.text,
) )
) )
@ -88,11 +88,11 @@ class RestClient(object):
def __init__(self): def __init__(self):
""" """
""" """
self.url_base = None # type: str self.url_base = None # type: str
self._active = False self._active = False
self._queue = Queue() self._queue = Queue()
self._pool = None # type: Pool self._pool = None # type: Pool
self.proxies = None self.proxies = None
@ -135,16 +135,16 @@ class RestClient(object):
self._queue.join() self._queue.join()
def add_request( def add_request(
self, self,
method: str, method: str,
path: str, path: str,
callback: Callable, callback: Callable,
params: dict = None, params: dict = None,
data: dict = None, data: dict = None,
headers: dict = None, headers: dict = None,
on_failed: Callable = None, on_failed: Callable = None,
on_error: Callable = None, on_error: Callable = None,
extra: Any = None extra: Any = None,
): ):
""" """
Add a new request. Add a new request.
@ -160,15 +160,7 @@ class RestClient(object):
:return: Request :return: Request
""" """
request = Request( request = Request(
method, method, path, params, data, headers, callback, on_failed, on_error, extra
path,
params,
data,
headers,
callback,
on_failed,
on_error,
extra
) )
self._queue.put(request) self._queue.put(request)
return request return request
@ -204,54 +196,34 @@ class RestClient(object):
sys.stderr.write(str(request)) sys.stderr.write(str(request))
def on_error( def on_error(
self, self, exception_type: type, exception_value: Exception, tb, request: Request
exception_type: type,
exception_value: Exception,
tb,
request: Request
): ):
""" """
Default on_error handler for Python exception. Default on_error handler for Python exception.
""" """
sys.stderr.write( sys.stderr.write(
self.exception_detail(exception_type, self.exception_detail(exception_type, exception_value, tb, request)
exception_value,
tb,
request)
) )
sys.excepthook(exception_type, exception_value, tb) sys.excepthook(exception_type, exception_value, tb)
def exception_detail( def exception_detail(
self, self, exception_type: type, exception_value: Exception, tb, request: Request
exception_type: type,
exception_value: Exception,
tb,
request: Request
): ):
text = "[{}]: Unhandled RestClient Error:{}\n".format( text = "[{}]: Unhandled RestClient Error:{}\n".format(
datetime.now().isoformat(), datetime.now().isoformat(), exception_type
exception_type
) )
text += "request:{}\n".format(request) text += "request:{}\n".format(request)
text += "Exception trace: \n" text += "Exception trace: \n"
text += "".join( text += "".join(traceback.format_exception(exception_type, exception_value, tb))
traceback.format_exception(
exception_type,
exception_value,
tb,
)
)
return text return text
def _process_request( def _process_request(
self, self, request: Request, session: requests.session
request: Request, ): # type: (Request, requests.Session)->None
session: requests.session
): # type: (Request, requests.Session)->None
""" """
Sending request to server and get result. Sending request to server and get result.
""" """
# noinspection PyBroadException # noinspection PyBroadException
try: try:
request = self.sign(request) request = self.sign(request)
@ -263,12 +235,12 @@ class RestClient(object):
headers=request.headers, headers=request.headers,
params=request.params, params=request.params,
data=request.data, data=request.data,
proxies=self.proxies proxies=self.proxies,
) )
request.response = response request.response = response
status_code = response.status_code status_code = response.status_code
if status_code / 100 == 2: # 2xx都算成功尽管交易所都用200 if status_code / 100 == 2: # 2xx都算成功尽管交易所都用200
jsonBody = response.json() jsonBody = response.json()
request.callback(jsonBody, request) request.callback(jsonBody, request)
request.status = RequestStatus.success request.status = RequestStatus.success

View File

@ -123,9 +123,9 @@ class WebsocketClient(object):
"""""" """"""
self._ws = self._create_connection( self._ws = self._create_connection(
self.host, self.host,
sslopt={'cert_reqs': ssl.CERT_NONE}, sslopt={"cert_reqs": ssl.CERT_NONE},
http_proxy_host=self.proxy_host, http_proxy_host=self.proxy_host,
http_proxy_port=self.proxy_port http_proxy_port=self.proxy_port,
) )
self.on_connected() self.on_connected()
@ -166,7 +166,7 @@ class WebsocketClient(object):
try: try:
data = self.unpack_data(text) data = self.unpack_data(text)
except ValueError as e: except ValueError as e:
print('websocket unable to parse data: ' + text) print("websocket unable to parse data: " + text)
raise e raise e
self.on_packet(data) self.on_packet(data)
@ -211,7 +211,7 @@ class WebsocketClient(object):
"""""" """"""
ws = self._get_ws() ws = self._get_ws()
if ws: if ws:
ws.send('ping', websocket.ABNF.OPCODE_PING) ws.send("ping", websocket.ABNF.OPCODE_PING)
@staticmethod @staticmethod
def on_connected(): def on_connected():
@ -238,36 +238,20 @@ class WebsocketClient(object):
""" """
Callback when exception raised. Callback when exception raised.
""" """
sys.stderr.write( sys.stderr.write(self.exception_detail(exception_type, exception_value, tb))
self.exception_detail(exception_type,
exception_value,
tb)
)
return sys.excepthook(exception_type, exception_value, tb) return sys.excepthook(exception_type, exception_value, tb)
def exception_detail( def exception_detail(self, exception_type: type, exception_value: Exception, tb):
self,
exception_type: type,
exception_value: Exception,
tb
):
""" """
Print detailed exception information. Print detailed exception information.
""" """
text = "[{}]: Unhandled WebSocket Error:{}\n".format( text = "[{}]: Unhandled WebSocket Error:{}\n".format(
datetime.now().isoformat(), datetime.now().isoformat(), exception_type
exception_type
) )
text += "LastSentText:\n{}\n".format(self._last_sent_text) text += "LastSentText:\n{}\n".format(self._last_sent_text)
text += "LastReceivedText:\n{}\n".format(self._last_received_text) text += "LastReceivedText:\n{}\n".format(self._last_received_text)
text += "Exception trace: \n" text += "Exception trace: \n"
text += "".join( text += "".join(traceback.format_exception(exception_type, exception_value, tb))
traceback.format_exception(
exception_type,
exception_value,
tb,
)
)
return text return text
def _record_last_sent_text(self, text: str): def _record_last_sent_text(self, text: str):
@ -280,4 +264,4 @@ class WebsocketClient(object):
""" """
Record last received text for debug purpose. Record last received text for debug purpose.
""" """
self._last_received_text = text[:1000] self._last_received_text = text[:1000]

View File

@ -8,10 +8,11 @@ from .base import APP_NAME
class CtaStrategyApp(BaseApp): class CtaStrategyApp(BaseApp):
"""""" """"""
app_name = APP_NAME app_name = APP_NAME
app_module = __module__ app_module = __module__
app_path = Path(__file__).parent app_path = Path(__file__).parent
display_name = "CTA策略" display_name = "CTA策略"
engine_class = CtaEngine engine_class = CtaEngine
widget_name = "CtaManager" widget_name = "CtaManager"
icon_name = "cta.ico" icon_name = "cta.ico"

View File

@ -4,6 +4,7 @@ from collections import defaultdict
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame from pandas import DataFrame
from vnpy.trader.constant import Interval, Status, Direction, Exchange from vnpy.trader.constant import Interval, Status, Direction, Exchange
@ -18,13 +19,16 @@ from .base import (
EngineType, EngineType,
StopOrder, StopOrder,
BacktestingMode, BacktestingMode,
ORDER_CTA2VT ORDER_CTA2VT,
) )
from .template import CtaTemplate from .template import CtaTemplate
sns.set_style("whitegrid")
class BacktestingEngine: class BacktestingEngine:
"""""" """"""
engine_type = EngineType.BACKTESTING engine_type = EngineType.BACKTESTING
gateway_name = "BACKTESTING" gateway_name = "BACKTESTING"
@ -39,7 +43,7 @@ class BacktestingEngine:
self.slippage = 0 self.slippage = 0
self.size = 1 self.size = 1
self.pricetick = 0 self.pricetick = 0
self.capital = 1000000 self.capital = 1_000_000
self.mode = BacktestingMode.BAR self.mode = BacktestingMode.BAR
self.strategy = None self.strategy = None
@ -92,17 +96,17 @@ class BacktestingEngine:
self.daily_results.clear() self.daily_results.clear()
def set_parameters( def set_parameters(
self, self,
vt_symbol: str, vt_symbol: str,
interval: Interval, interval: Interval,
start: datetime, start: datetime,
rate: float, rate: float,
slippage: float, slippage: float,
size: float, size: float,
pricetick: float, pricetick: float,
capital: int = 0, capital: int = 0,
end: datetime = None, end: datetime = None,
mode: BacktestingMode = None mode: BacktestingMode = None,
): ):
"""""" """"""
self.mode = mode self.mode = mode
@ -124,10 +128,7 @@ class BacktestingEngine:
def add_strategy(self, strategy_class: type, setting: dict): def add_strategy(self, strategy_class: type, setting: dict):
"""""" """"""
self.strategy = strategy_class( self.strategy = strategy_class(
self, self, strategy_class.__name__, self.vt_symbol, setting
strategy_class.__name__,
self.vt_symbol,
setting
) )
def load_data(self): def load_data(self):
@ -135,18 +136,26 @@ class BacktestingEngine:
self.output("开始加载历史数据") self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR: if self.mode == BacktestingMode.BAR:
s = DbBarData.select().where( s = (
(DbBarData.vt_symbol == self.vt_symbol) & DbBarData.select()
(DbBarData.interval == self.interval) & .where(
(DbBarData.datetime >= self.start) & (DbBarData.vt_symbol == self.vt_symbol)
(DbBarData.datetime <= self.end) & (DbBarData.interval == self.interval)
).order_by(DbBarData.datetime) & (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
)
else: else:
s = DbTickData.select().where( s = (
(DbTickData.vt_symbol == self.vt_symbol) & DbTickData.select()
(DbTickData.datetime >= self.start) & .where(
(DbTickData.datetime <= self.end) (DbTickData.vt_symbol == self.vt_symbol)
).order_by(DbTickData.datetime) & (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
)
self.history_data = list(s) self.history_data = list(s)
@ -164,7 +173,7 @@ class BacktestingEngine:
# Use the first [days] of history data for initializing strategy # Use the first [days] of history data for initializing strategy
day_count = 0 day_count = 0
for ix, data in enumerate(self.history_data): for ix, data in enumerate(self.history_data):
if (self.datetime and data.datetime.day != self.datetime.day): if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1 day_count += 1
if day_count >= self.days: if day_count >= self.days:
break break
@ -205,11 +214,7 @@ class BacktestingEngine:
for daily_result in self.daily_results.values(): for daily_result in self.daily_results.values():
daily_result.calculate_pnl( daily_result.calculate_pnl(
pre_close, pre_close, start_pos, self.size, self.rate, self.slippage
start_pos,
self.size,
self.rate,
self.slippage
) )
pre_close = daily_result.close_price pre_close = daily_result.close_price
@ -236,78 +241,82 @@ class BacktestingEngine:
# Calculate balance related time series data # Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital df["balance"] = df["net_pnl"].cumsum() + self.capital
df["return"] = (np.log(df["balance}" - np.log(df["balance"].shift(1))).fillna(0) df["return"] = (np.log(df["balance"] - np.log(df["balance"].shift(1)))).fillna(
df["highlevel"] = df["balance"].rolling(min_periods=1,window=len(df),center=False).max() 0
df["drawdown"] = df["balance"] - df["highlevel"] )
df["highlevel"] = (
df["balance"].rolling(min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100 df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
# Calculate statistics value # Calculate statistics value
start_date = df.index[0] start_date = df.index[0]
end_date = df.index[-1] end_date = df.index[-1]
total_days = len(df) total_days = len(df)
profit_days = len(df[df["netPnl"]>0]) profit_days = len(df[df["netPnl"] > 0])
loss_days = len(df[df["netPnl"]<0]) loss_days = len(df[df["netPnl"] < 0])
end_balance = df["balance"].iloc[-1] end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min() max_drawdown = df["drawdown"].min()
max_ddpercent = df["ddpercent"].min() max_ddpercent = df["ddpercent"].min()
total_net_pnl = df["net_pnl"].sum() total_net_pnl = df["net_pnl"].sum()
daily_net_pnl = total_net_pnl / total_days daily_net_pnl = total_net_pnl / total_days
total_commission = df["commission"].sum() total_commission = df["commission"].sum()
daily_commission = total_commission / total_days daily_commission = total_commission / total_days
total_slippage = df["slippage"].sum() total_slippage = df["slippage"].sum()
daily_slippage = total_slippage / total_days daily_slippage = total_slippage / total_days
total_turnover = df["turnover"].sum() total_turnover = df["turnover"].sum()
daily_turnover = total_turnover / total_days daily_turnover = total_turnover / total_days
total_trade_count = df["trade_count"].sum() total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days daily_trade_count = total_trade_count / total_days
total_return = (end_balance/self.capital - 1) * 100 total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * 240 annual_return = total_return / total_days * 240
daily_return = df["return"].mean() * 100 daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100 return_std = df["return"].std() * 100
if return_std: if return_std:
sharpe_ratio = daily_return / return_std * np.sqrt(240) sharpe_ratio = daily_return / return_std * np.sqrt(240)
else: else:
sharpe_ratio = 0 sharpe_ratio = 0
# Output # Output
# 输出统计结果 # 输出统计结果
self.output("-" * 30) self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}") self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}") self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}") self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}") self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}") self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}") self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}") self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%") self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%") self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}%") self.output(f"最大回撤: \t{max_drawdown:,.2f}%")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}%") self.output(f"总盈亏:\t{total_net_pnl:,.2f}%")
self.output(f"总手续费:\t{total_commission:,.2f}") self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}") self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}") self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}") self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}") self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}") self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}") self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}") self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%") self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%") self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}") self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
@ -335,33 +344,33 @@ class BacktestingEngine:
"annual_return": annual_return, "annual_return": annual_return,
"daily_return": daily_return, "daily_return": daily_return,
"return_std": return_std, "return_std": return_std,
"sharpe_ratio": sharpe_ratio "sharpe_ratio": sharpe_ratio,
} }
return statistics return statistics
def show_chart(self, df: DataFrame = None): def show_chart(self, df: DataFrame = None):
"""""" """"""
if not df: if not df:
df = self.daily_df df = self.daily_df
fig = plt.figure(figsize=(10, 16)) fig = plt.figure(figsize=(10, 16))
balance_plot = plt.subplot(4, 1, 1) balance_plot = plt.subplot(4, 1, 1)
balance_plot.set_title('Balance') balance_plot.set_title("Balance")
df['balance'].plot(legend=True) df["balance"].plot(legend=True)
drawdown_plot = plt.subplot(4, 1, 2) drawdown_plot = plt.subplot(4, 1, 2)
drawdown_plot.set_title('Drawdown') drawdown_plot.set_title("Drawdown")
drawdown_plot.fill_between(range(len(df)), df['drawdown'].values) drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
pnl_plot = plt.subplot(4, 1, 3) pnl_plot = plt.subplot(4, 1, 3)
pnl_plot.set_title('Daily Pnl') pnl_plot.set_title("Daily Pnl")
df['net_pnl'].plot(kind='bar', legend=False, grid=False, xticks=[]) df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
distribution_plot = plt.subplot(4, 1, 4) distribution_plot = plt.subplot(4, 1, 4)
distribution_plot.set_title('Daily Pnl Distribution') distribution_plot.set_title("Daily Pnl Distribution")
df['net_pnl'].hist(bins=50) df["net_pnl"].hist(bins=50)
plt.show() plt.show()
@ -421,12 +430,14 @@ class BacktestingEngine:
# Check whether limit orders can be filled. # Check whether limit orders can be filled.
long_cross = ( long_cross = (
order.direction == Direction.LONG order.direction == Direction.LONG
and order.price >= long_cross_price and long_cross_price > 0 and order.price >= long_cross_price
and long_cross_price > 0
) )
short_cross = ( short_cross = (
order.direction == Direction.SHORT order.direction == Direction.SHORT
and order.price <= short_cross_price and short_cross_price > 0 and order.price <= short_cross_price
and short_cross_price > 0
) )
if not long_cross and not short_cross: if not long_cross and not short_cross:
@ -459,7 +470,7 @@ class BacktestingEngine:
price=trade_price, price=trade_price,
volume=order.volume, volume=order.volume,
time=self.datetime.strftime("%H:%M:%S"), time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
trade.datetime = self.datetime trade.datetime = self.datetime
@ -510,7 +521,7 @@ class BacktestingEngine:
price=stop_order.price, price=stop_order.price,
volume=stop_order.volume, volume=stop_order.volume,
status=Status.ALLTRADED, status=Status.ALLTRADED,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.limit_orders[order.vt_orderid] = order self.limit_orders[order.vt_orderid] = order
@ -535,7 +546,7 @@ class BacktestingEngine:
price=trade_price, price=trade_price,
volume=order.volume, volume=order.volume,
time=self.datetime.strftime("%H:%M:%S"), time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
trade.datetime = self.datetime trade.datetime = self.datetime
@ -553,11 +564,7 @@ class BacktestingEngine:
self.strategy.on_trade(trade) self.strategy.on_trade(trade)
def load_bar( def load_bar(
self, self, vt_symbol: str, days: int, interval: Interval, callback: Callable
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable
): ):
"""""" """"""
self.days = days self.days = days
@ -569,12 +576,12 @@ class BacktestingEngine:
self.callback = callback self.callback = callback
def send_order( def send_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
order_type: CtaOrderType, order_type: CtaOrderType,
price: float, price: float,
volume: float, volume: float,
stop: bool = False stop: bool = False,
): ):
"""""" """"""
if stop: if stop:
@ -582,12 +589,7 @@ class BacktestingEngine:
else: else:
return self.send_limit_order(order_type, price, volume) return self.send_limit_order(order_type, price, volume)
def send_stop_order( def send_stop_order(self, order_type: CtaOrderType, price: float, volume: float):
self,
order_type: CtaOrderType,
price: float,
volume: float
):
"""""" """"""
self.stop_order_count += 1 self.stop_order_count += 1
@ -597,7 +599,7 @@ class BacktestingEngine:
price=price, price=price,
volume=volume, volume=volume,
stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}", stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}",
strategy_name=self.strategy.strategy_name strategy_name=self.strategy.strategy_name,
) )
self.active_stop_orders[stop_order.stop_orderid] = stop_order self.active_stop_orders[stop_order.stop_orderid] = stop_order
@ -605,12 +607,7 @@ class BacktestingEngine:
return stop_order.stop_orderid return stop_order.stop_orderid
def send_limit_order( def send_limit_order(self, order_type: CtaOrderType, price: float, volume: float):
self,
order_type: CtaOrderType,
price: float,
volume: float
):
"""""" """"""
self.limit_order_count += 1 self.limit_order_count += 1
direction, offset = ORDER_CTA2VT[order_type] direction, offset = ORDER_CTA2VT[order_type]
@ -624,7 +621,7 @@ class BacktestingEngine:
price=price, price=price,
volume=volume, volume=volume,
status=Status.NOTTRADED, status=Status.NOTTRADED,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.active_limit_orders[order.vt_orderid] = order self.active_limit_orders[order.vt_orderid] = order
@ -724,19 +721,17 @@ class DailyResult:
self.trades.append(trade) self.trades.append(trade)
def calculate_pnl( def calculate_pnl(
self, self,
pre_close: float, pre_close: float,
start_pos: float, start_pos: float,
size: int, size: int,
rate: float, rate: float,
slippage: float slippage: float,
): ):
"""""" """"""
# Holding pnl is the pnl from holding position at day start # Holding pnl is the pnl from holding position at day start
self.start_pos = self.end_pos = start_pos self.start_pos = self.end_pos = start_pos
self.holding_pnl = self.start_pos * ( self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size
self.close_price - self.pre_close
) * size
# Trading pnl is the pnl from new trade during the day # Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades) self.trade_count = len(self.trades)
@ -749,9 +744,7 @@ class DailyResult:
pos_change -= trade.volume pos_change -= trade.volume
turnover = trade.price * trade.volume * size turnover = trade.price * trade.volume * size
self.trading_pnl += pos_change * ( self.trading_pnl += pos_change * (self.close_price - trade.price) * size
self.close_price - trade.price
) * size
self.end_pos += pos_change self.end_pos += pos_change
self.turnover += turnover self.turnover += turnover
self.commission += turnover * rate self.commission += turnover * rate
@ -760,3 +753,39 @@ class DailyResult:
# Net pnl takes account of commission and slippage cost # Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage self.net_pnl = self.total_pnl - self.commission - self.slippage
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list

View File

@ -51,17 +51,13 @@ class StopOrder:
self.direction, self.offset = ORDER_CTA2VT[self.order_type] self.direction, self.offset = ORDER_CTA2VT[self.order_type]
EVENT_CTA_LOG = 'eCtaLog' EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = 'eCtaStrategy' EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = 'eCtaStopOrder' EVENT_CTA_STOPORDER = "eCtaStopOrder"
ORDER_CTA2VT = { ORDER_CTA2VT = {
CtaOrderType.BUY: (Direction.LONG, CtaOrderType.BUY: (Direction.LONG, Offset.OPEN),
Offset.OPEN), CtaOrderType.SELL: (Direction.SHORT, Offset.CLOSE),
CtaOrderType.SELL: (Direction.SHORT, CtaOrderType.SHORT: (Direction.SHORT, Offset.OPEN),
Offset.CLOSE), CtaOrderType.COVER: (Direction.LONG, Offset.CLOSE),
CtaOrderType.SHORT: (Direction.SHORT,
Offset.OPEN),
CtaOrderType.COVER: (Direction.LONG,
Offset.CLOSE),
} }

View File

@ -15,7 +15,7 @@ from vnpy.trader.object import (
CancelRequest, CancelRequest,
SubscribeRequest, SubscribeRequest,
LogData, LogData,
TickData TickData,
) )
from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE
from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interval from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interval
@ -31,36 +31,32 @@ from .base import (
ORDER_CTA2VT, ORDER_CTA2VT,
EVENT_CTA_LOG, EVENT_CTA_LOG,
EVENT_CTA_STRATEGY, EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER EVENT_CTA_STOPORDER,
) )
class CtaEngine(BaseEngine): class CtaEngine(BaseEngine):
"""""" """"""
engine_type = EngineType.LIVE # live trading engine
engine_type = EngineType.LIVE # live trading engine
filename = "CtaStrategy.vt" filename = "CtaStrategy.vt"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine): def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
"""""" """"""
super(CtaEngine, super(CtaEngine, self).__init__(main_engine, event_engine, "CtaStrategy")
self).__init__(main_engine,
event_engine,
"CtaStrategy")
self.setting_file = None # setting file object self.setting_file = None # setting file object
self.classes = {} # class_name: stategy_class self.classes = {} # class_name: stategy_class
self.strategies = {} # strategy_name: strategy self.strategies = {} # strategy_name: strategy
self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list
self.orderid_strategy_map = {} # vt_orderid: strategy self.orderid_strategy_map = {} # vt_orderid: strategy
self.strategy_orderid_map = defaultdict( self.strategy_orderid_map = defaultdict(set) # strategy_name: orderid list
set
) # strategy_name: orderid list
self.stop_order_count = 0 # for generating stop_orderid self.stop_order_count = 0 # for generating stop_orderid
self.stop_orders = {} # stop_orderid: stop_order self.stop_orders = {} # stop_orderid: stop_order
def init_engine(self): def init_engine(self):
""" """
@ -131,12 +127,10 @@ class CtaEngine(BaseEngine):
continue continue
long_triggered = ( long_triggered = (
so.direction == Direction.LONG so.direction == Direction.LONG and tick.last_price >= stop_order.price
and tick.last_price >= stop_order.price
) )
short_triggered = ( short_triggered = (
so.direction == Direction.SHORT so.direction == Direction.SHORT and tick.last_price <= stop_order.price
and tick.last_price <= stop_order.price
) )
if long_triggered or short_triggered: if long_triggered or short_triggered:
@ -157,10 +151,7 @@ class CtaEngine(BaseEngine):
price = tick.bid_price_5 price = tick.bid_price_5
vt_orderid = self.send_limit_order( vt_orderid = self.send_limit_order(
strategy, strategy, stop_order.order_type, price, stop_order.volume
stop_order.order_type,
price,
stop_order.volume
) )
# Update stop order status if placed successfully # Update stop order status if placed successfully
@ -177,17 +168,15 @@ class CtaEngine(BaseEngine):
stop_order.vt_orderid = vt_orderid stop_order.vt_orderid = vt_orderid
self.call_strategy_func( self.call_strategy_func(
strategy, strategy, strategy.on_stop_order, stop_order
strategy.on_stop_order,
stop_order
) )
def send_limit_order( def send_limit_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
order_type: CtaOrderType, order_type: CtaOrderType,
price: float, price: float,
volume: float volume: float,
): ):
""" """
Send a new order. Send a new order.
@ -207,12 +196,9 @@ class CtaEngine(BaseEngine):
offset=offset, offset=offset,
price_type=PriceType.LIMIT, price_type=PriceType.LIMIT,
price=price, price=price,
volume=volume volume=volume,
)
vt_orderid = self.main_engine.send_limit_order(
req,
contract.gateway_name
) )
vt_orderid = self.main_engine.send_limit_order(req, contract.gateway_name)
# Save relationship between orderid and strategy. # Save relationship between orderid and strategy.
self.orderid_strategy_map[vt_orderid] = strategy self.orderid_strategy_map[vt_orderid] = strategy
@ -223,11 +209,11 @@ class CtaEngine(BaseEngine):
return vt_orderid return vt_orderid
def send_stop_order( def send_stop_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
order_type: CtaOrderType, order_type: CtaOrderType,
price: float, price: float,
volume: float volume: float,
): ):
""" """
Send a new order. Send a new order.
@ -243,7 +229,7 @@ class CtaEngine(BaseEngine):
price=price, price=price,
volume=volume, volume=volume,
stop_orderid=stop_orderid, stop_orderid=stop_orderid,
strategy_name=strategy.strategy_name strategy_name=strategy.strategy_name,
) )
self.stop_orders[stop_orderid] = stop_order self.stop_orders[stop_orderid] = stop_order
@ -289,12 +275,12 @@ class CtaEngine(BaseEngine):
self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
def send_order( def send_order(
self, self,
strategy: CtaTemplate, strategy: CtaTemplate,
order_type: CtaOrderType, order_type: CtaOrderType,
price: float, price: float,
volume: float, volume: float,
stop: bool stop: bool,
): ):
""" """
""" """
@ -327,11 +313,7 @@ class CtaEngine(BaseEngine):
return self.engine_type return self.engine_type
def load_bar( def load_bar(
self, self, vt_symbol: str, days: int, interval: Interval, callback: Callable
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable
): ):
"""""" """"""
pass pass
@ -341,10 +323,7 @@ class CtaEngine(BaseEngine):
pass pass
def call_strategy_func( def call_strategy_func(
self, self, strategy: CtaTemplate, func: Callable, params: Any = None
strategy: CtaTemplate,
func: Callable,
params: Any = None
): ):
""" """
Call function of a strategy and catch any exception raised. Call function of a strategy and catch any exception raised.
@ -362,11 +341,7 @@ class CtaEngine(BaseEngine):
self.write_log(msg, strategy) self.write_log(msg, strategy)
def add_strategy( def add_strategy(
self, self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict
class_name: str,
strategy_name: str,
vt_symbol: str,
setting: dict
): ):
""" """
Add a new strategy. Add a new strategy.
@ -462,29 +437,18 @@ class CtaEngine(BaseEngine):
Load strategy class from source code. Load strategy class from source code.
""" """
path1 = Path(__file__).parent.joinpath("strategies") path1 = Path(__file__).parent.joinpath("strategies")
self.load_strategy_class_from_folder( self.load_strategy_class_from_folder(path1, "vnpy.app.cta_strategy.strategies")
path1,
"vnpy.app.cta_strategy.strategies"
)
path2 = Path.cwd().joinpath("strategies") path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies") self.load_strategy_class_from_folder(path2, "strategies")
def load_strategy_class_from_folder( def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
self,
path: Path,
module_name: str = ""
):
""" """
Load strategy class from certain folder. Load strategy class from certain folder.
""" """
for dirpath, dirnames, filenames in os.walk(path): for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames: for filename in filenames:
module_name = ".".join( module_name = ".".join([module_name, filename.replace(".py", "")])
[module_name,
filename.replace(".py",
"")]
)
self.load_strategy_class_from_module(module_name) self.load_strategy_class_from_module(module_name)
def load_strategy_class_from_module(self, module_name: str): def load_strategy_class_from_module(self, module_name: str):
@ -566,7 +530,7 @@ class CtaEngine(BaseEngine):
strategy.__class__.__name__, strategy.__class__.__name__,
strategy_name, strategy_name,
strategy.vt_symbol, strategy.vt_symbol,
setting setting,
) )
self.setting_file.sync() self.setting_file.sync()
@ -611,4 +575,4 @@ class CtaEngine(BaseEngine):
log = LogData(msg=msg, gateway_name="CtaStrategy") log = LogData(msg=msg, gateway_name="CtaStrategy")
event = Event(type=EVENT_CTA_LOG, data=log) event = Event(type=EVENT_CTA_LOG, data=log)
self.event_engine.put(event) self.event_engine.put(event)

View File

@ -16,8 +16,6 @@ class DoubleMaStrategy(CtaTemplate):
def __init__(self, cta_engine, strategy_name, vt_symbol, setting): def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
"""""" """"""
super(DoubleMaStrategy, super(DoubleMaStrategy, self).__init__(
self).__init__(cta_engine, cta_engine, strategy_name, vt_symbol, setting
strategy_name, )
vt_symbol,
setting)

View File

@ -1,7 +1,7 @@
"""""" """"""
from abc import ABC from abc import ABC
from typing import Any from typing import Any, Callable
from vnpy.trader.object import TickData, OrderData, TradeData, BarData from vnpy.trader.object import TickData, OrderData, TradeData, BarData
from vnpy.trader.constant import Interval from vnpy.trader.constant import Interval
@ -17,11 +17,7 @@ class CtaTemplate(ABC):
variables = [] variables = []
def __init__( def __init__(
self, self, cta_engine: Any, strategy_name: str, vt_symbol: str, setting: dict
cta_engine: Any,
strategy_name: str,
vt_symbol: str,
setting: dict
): ):
"""""" """"""
self.cta_engine = cta_engine self.cta_engine = cta_engine
@ -84,7 +80,7 @@ class CtaTemplate(ABC):
"class_name": self.__class__.__name__, "class_name": self.__class__.__name__,
"author": self.author, "author": self.author,
"parameters": self.get_parameters(), "parameters": self.get_parameters(),
"variables": self.get_variables() "variables": self.get_variables(),
} }
return strategy_data return strategy_data
@ -155,11 +151,7 @@ class CtaTemplate(ABC):
return self.send_order(CtaOrderType.COVER, price, volume, stop) return self.send_order(CtaOrderType.COVER, price, volume, stop)
def send_order( def send_order(
self, self, order_type: CtaOrderType, price: float, volume: float, stop: bool = False
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool = False
): ):
""" """
Send a new order. Send a new order.
@ -191,14 +183,14 @@ class CtaTemplate(ABC):
return self.cta_engine.get_engine_type() return self.cta_engine.get_engine_type()
def load_bar( def load_bar(
self, self, days: int, interval: Interval = Interval.MINUTE, callback: Callable = None
days: int,
interval: Interval = Interval.MINUTE,
callback=self.on_bar
): ):
""" """
Load historical bar data for initializing strategy. Load historical bar data for initializing strategy.
""" """
if not callback:
callback = self.on_bar
self.cta_engine.load_bar(self.vt_symbol, days, interval, callback) self.cta_engine.load_bar(self.vt_symbol, days, interval, callback)
def load_tick(self, days: int): def load_tick(self, days: int):

View File

@ -1 +1 @@
from .widget import CtaManager from .widget import CtaManager

View File

@ -11,6 +11,7 @@ from ..base import APP_NAME, EVENT_CTA_LOG, EVENT_CTA_STOPORDER, EVENT_CTA_STRAT
class CtaManager(QtWidgets.QWidget): class CtaManager(QtWidgets.QWidget):
"""""" """"""
signal_log = QtCore.pyqtSignal(Event) signal_log = QtCore.pyqtSignal(Event)
signal_strategy = QtCore.pyqtSignal(Event) signal_strategy = QtCore.pyqtSignal(Event)
@ -61,10 +62,7 @@ class CtaManager(QtWidgets.QWidget):
self.log_monitor = LogMonitor(self.main_engine, self.event_engine) self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
self.stop_order_monitor = StopOrderMonitor( self.stop_order_monitor = StopOrderMonitor(self.main_engine, self.event_engine)
self.main_engine,
self.event_engine
)
# Set layout # Set layout
hbox1 = QtWidgets.QHBoxLayout() hbox1 = QtWidgets.QHBoxLayout()
@ -88,18 +86,13 @@ class CtaManager(QtWidgets.QWidget):
def update_class_combo(self): def update_class_combo(self):
"""""" """"""
self.class_combo.addItems( self.class_combo.addItems(self.cta_engine.get_all_strategy_class_names())
self.cta_engine.get_all_strategy_class_names()
)
def register_event(self): def register_event(self):
"""""" """"""
self.signal_strategy.connect(self.process_strategy_event) self.signal_strategy.connect(self.process_strategy_event)
self.event_engine.register( self.event_engine.register(EVENT_CTA_STRATEGY, self.signal_strategy.emit)
EVENT_CTA_STRATEGY,
self.signal_strategy.emit
)
def process_strategy_event(self, event): def process_strategy_event(self, event):
""" """
@ -136,12 +129,7 @@ class CtaManager(QtWidgets.QWidget):
vt_symbol = setting.pop("vt_symbol") vt_symbol = setting.pop("vt_symbol")
strategy_name = setting.pop("strategy_name") strategy_name = setting.pop("strategy_name")
self.cta_engine.add_strategy( self.cta_engine.add_strategy(class_name, strategy_name, vt_symbol, setting)
class_name,
strategy_name,
vt_symbol,
setting
)
def show(self): def show(self):
"""""" """"""
@ -153,12 +141,7 @@ class StrategyManager(QtWidgets.QFrame):
Manager for a strategy Manager for a strategy
""" """
def __init__( def __init__(self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict):
self,
cta_manager: CtaManager,
cta_engine: CtaEngine,
data: dict
):
"""""" """"""
super(StrategyManager, self).__init__() super(StrategyManager, self).__init__()
@ -277,9 +260,7 @@ class DataMonitor(QtWidgets.QTableWidget):
self.setHorizontalHeaderLabels(labels) self.setHorizontalHeaderLabels(labels)
self.setRowCount(1) self.setRowCount(1)
self.verticalHeader().setSectionResizeMode( self.verticalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
QtWidgets.QHeaderView.Stretch
)
self.verticalHeader().setVisible(False) self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers) self.setEditTriggers(self.NoEditTriggers)
@ -320,51 +301,20 @@ class StopOrderMonitor(BaseMonitor):
""" """
Monitor for local stop order. Monitor for local stop order.
""" """
event_type = EVENT_CTA_STOPORDER event_type = EVENT_CTA_STOPORDER
data_key = "stop_orderid" data_key = "stop_orderid"
sorting = True sorting = True
headers = { headers = {
"stop_orderid": { "stop_orderid": {"display": "停止委托号", "cell": BaseCell, "update": False},
"display": "停止委托号", "vt_orderid": {"display": "限价委托号", "cell": BaseCell, "update": True},
"cell": BaseCell, "vt_symbol": {"display": "代码", "cell": BaseCell, "update": False},
"update": False "order_type": {"display": "类型", "cell": EnumCell, "update": False},
}, "price": {"display": "价格", "cell": BaseCell, "update": False},
"vt_orderid": { "volume": {"display": "数量", "cell": BaseCell, "update": True},
"display": "限价委托号", "status": {"display": "状态", "cell": EnumCell, "update": True},
"cell": BaseCell, "strategy": {"display": "策略名", "cell": StrategyCell, "update": True},
"update": True
},
"vt_symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"order_type": {
"display": "类型",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": True
},
"status": {
"display": "状态",
"cell": EnumCell,
"update": True
},
"strategy": {
"display": "策略名",
"cell": StrategyCell,
"update": True
}
} }
def init_ui(self): def init_ui(self):
@ -373,30 +323,21 @@ class StopOrderMonitor(BaseMonitor):
""" """
super(StopOrderMonitor, self).init_ui() super(StopOrderMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode( self.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
QtWidgets.QHeaderView.Stretch
)
class LogMonitor(BaseMonitor): class LogMonitor(BaseMonitor):
""" """
Monitor for log data. Monitor for log data.
""" """
event_type = EVENT_CTA_LOG event_type = EVENT_CTA_LOG
data_key = "" data_key = ""
sorting = False sorting = False
headers = { headers = {
"time": { "time": {"display": "时间", "cell": TimeCell, "update": False},
"display": "时间", "msg": {"display": "信息", "cell": MsgCell, "update": False},
"cell": TimeCell,
"update": False
},
"msg": {
"display": "信息",
"cell": MsgCell,
"update": False
}
} }
def init_ui(self): def init_ui(self):
@ -405,10 +346,7 @@ class LogMonitor(BaseMonitor):
""" """
super(LogMonitor, self).init_ui() super(LogMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode( self.horizontalHeader().setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch)
1,
QtWidgets.QHeaderView.Stretch
)
def insert_new_row(self, data): def insert_new_row(self, data):
""" """
@ -423,12 +361,7 @@ class SettingEditor(QtWidgets.QDialog):
For creating new strategy and editing strategy parameters. For creating new strategy and editing strategy parameters.
""" """
def __init__( def __init__(self, parameters: dict, strategy_name: str = "", class_name: str = ""):
self,
parameters: dict,
strategy_name: str = "",
class_name: str = ""
):
"""""" """"""
super(SettingEditor, self).__init__() super(SettingEditor, self).__init__()

View File

@ -1 +1 @@
from .engine import Event, EventEngine, EVENT_TIMER from .engine import Event, EventEngine, EVENT_TIMER

View File

@ -1 +1 @@
from .bitmex_gateway import BitmexGateway from .bitmex_gateway import BitmexGateway

View File

@ -31,7 +31,7 @@ from vnpy.trader.object import (
TradeData, TradeData,
PositionData, PositionData,
AccountData, AccountData,
ContractData ContractData,
) )
from vnpy.trader.constant import Direction, Status, PriceType, Exchange, Product from vnpy.trader.constant import Direction, Status, PriceType, Exchange, Product
@ -46,7 +46,7 @@ STATUS_BITMEX2VT = {
"Partially filled": Status.PARTTRADED, "Partially filled": Status.PARTTRADED,
"Filled": Status.ALLTRADED, "Filled": Status.ALLTRADED,
"Canceled": Status.CANCELLED, "Canceled": Status.CANCELLED,
"Rejected": Status.REJECTED "Rejected": Status.REJECTED,
} }
DIRECTION_VT2BITMEX = {Direction.LONG: "Buy", Direction.SHORT: "Sell"} DIRECTION_VT2BITMEX = {Direction.LONG: "Buy", Direction.SHORT: "Sell"}
@ -64,10 +64,9 @@ class BitmexGateway(BaseGateway):
"key": "", "key": "",
"secret": "", "secret": "",
"session": 3, "session": 3,
"server": ["REAL", "server": ["REAL", "TESTNET"],
"TESTNET"],
"proxy_host": "127.0.0.1", "proxy_host": "127.0.0.1",
"proxy_port": 1080 "proxy_port": 1080,
} }
def __init__(self, event_engine): def __init__(self, event_engine):
@ -86,14 +85,7 @@ class BitmexGateway(BaseGateway):
proxy_host = setting["proxy_host"] proxy_host = setting["proxy_host"]
proxy_port = setting["proxy_port"] proxy_port = setting["proxy_port"]
self.rest_api.connect( self.rest_api.connect(key, secret, session, server, proxy_host, proxy_port)
key,
secret,
session,
server,
proxy_host,
proxy_port
)
self.ws_api.connect(key, secret, server, proxy_host, proxy_port) self.ws_api.connect(key, secret, server, proxy_host, proxy_port)
@ -138,7 +130,7 @@ class BitmexRestApi(RestClient):
self.key = "" self.key = ""
self.secret = "" self.secret = ""
self.order_count = 1000000 self.order_count = 1_000_000
self.connect_time = 0 self.connect_time = 0
def sign(self, request): def sign(self, request):
@ -161,9 +153,7 @@ class BitmexRestApi(RestClient):
msg = request.method + "/api/v1" + path + str(expires) + request.data msg = request.method + "/api/v1" + path + str(expires) + request.data
signature = hmac.new( signature = hmac.new(
self.secret, self.secret, msg.encode(), digestmod=hashlib.sha256
msg.encode(),
digestmod=hashlib.sha256
).hexdigest() ).hexdigest()
# Add headers # Add headers
@ -172,20 +162,20 @@ class BitmexRestApi(RestClient):
"Accept": "application/json", "Accept": "application/json",
"api-key": self.key, "api-key": self.key,
"api-expires": str(expires), "api-expires": str(expires),
"api-signature": signature "api-signature": signature,
} }
request.headers = headers request.headers = headers
return request return request
def connect( def connect(
self, self,
key: str, key: str,
secret: str, secret: str,
session: int, session: int,
server: str, server: str,
proxy_host: str, proxy_host: str,
proxy_port: int proxy_port: int,
): ):
""" """
Initialize connection to REST server. Initialize connection to REST server.
@ -193,9 +183,9 @@ class BitmexRestApi(RestClient):
self.key = key self.key = key
self.secret = secret.encode() self.secret = secret.encode()
self.connect_time = int( self.connect_time = (
datetime.now().strftime("%y%m%d%H%M%S") int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
) * self.order_count )
if server == "REAL": if server == "REAL":
self.init(REST_HOST, proxy_host, proxy_port) self.init(REST_HOST, proxy_host, proxy_port)
@ -204,7 +194,7 @@ class BitmexRestApi(RestClient):
self.start(session) self.start(session)
self.gateway.write_log(u"REST API启动成功") self.gateway.write_log("REST API启动成功")
def send_order(self, req: SubscribeRequest): def send_order(self, req: SubscribeRequest):
"""""" """"""
@ -217,7 +207,7 @@ class BitmexRestApi(RestClient):
"ordType": PRICETYPE_VT2BITMEX[req.price_type], "ordType": PRICETYPE_VT2BITMEX[req.price_type],
"price": req.price, "price": req.price,
"orderQty": int(req.volume), "orderQty": int(req.volume),
"clOrdID": orderid "clOrdID": orderid,
} }
# Only add price for limit order. # Only add price for limit order.
@ -268,11 +258,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg) self.gateway.write_log(msg)
def on_send_order_error( def on_send_order_error(
self, self, exception_type: type, exception_value: Exception, tb, request: Request
exception_type: type,
exception_value: Exception,
tb,
request: Request
): ):
""" """
Callback when sending order caused exception. Callback when sending order caused exception.
@ -290,11 +276,7 @@ class BitmexRestApi(RestClient):
pass pass
def on_cancel_order_error( def on_cancel_order_error(
self, self, exception_type: type, exception_value: Exception, tb, request: Request
exception_type: type,
exception_value: Exception,
tb,
request: Request
): ):
""" """
Callback when cancelling order failed on server. Callback when cancelling order failed on server.
@ -315,11 +297,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg) self.gateway.write_log(msg)
def on_error( def on_error(
self, self, exception_type: type, exception_value: Exception, tb, request: Request
exception_type: type,
exception_value: Exception,
tb,
request: Request
): ):
""" """
Callback to handler request exception. Callback to handler request exception.
@ -328,10 +306,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg) self.gateway.write_log(msg)
sys.stderr.write( sys.stderr.write(
self.exception_detail(exception_type, self.exception_detail(exception_type, exception_value, tb, request)
exception_value,
tb,
request)
) )
@ -355,7 +330,7 @@ class BitmexWebsocketApi(WebsocketClient):
"order": self.on_order, "order": self.on_order,
"position": self.on_position, "position": self.on_position,
"margin": self.on_account, "margin": self.on_account,
"instrument": self.on_contract "instrument": self.on_contract,
} }
self.ticks = {} self.ticks = {}
@ -364,12 +339,7 @@ class BitmexWebsocketApi(WebsocketClient):
self.trades = set() self.trades = set()
def connect( def connect(
self, self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int
key: str,
secret: str,
server: str,
proxy_host: str,
proxy_port: int
): ):
"""""" """"""
self.key = key self.key = key
@ -391,23 +361,23 @@ class BitmexWebsocketApi(WebsocketClient):
exchange=req.exchange, exchange=req.exchange,
name=req.symbol, name=req.symbol,
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.ticks[req.symbol] = tick self.ticks[req.symbol] = tick
def on_connected(self): def on_connected(self):
"""""" """"""
self.gateway.write_log(u"Websocket API连接成功") self.gateway.write_log("Websocket API连接成功")
self.authenticate() self.authenticate()
def on_disconnected(self): def on_disconnected(self):
"""""" """"""
self.gateway.write_log(u"Websocket API连接断开") self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict): def on_packet(self, packet: dict):
"""""" """"""
if "error" in packet: if "error" in packet:
self.gateway.write_log(u"Websocket API报错%s" % packet["error"]) self.gateway.write_log("Websocket API报错%s" % packet["error"])
if "not valid" in packet["error"]: if "not valid" in packet["error"]:
self.active = False self.active = False
@ -418,7 +388,7 @@ class BitmexWebsocketApi(WebsocketClient):
if success: if success:
if req["op"] == "authKey": if req["op"] == "authKey":
self.gateway.write_log(u"Websocket API验证授权成功") self.gateway.write_log("Websocket API验证授权成功")
self.subscribe_topic() self.subscribe_topic()
elif "table" in packet: elif "table" in packet:
@ -436,11 +406,7 @@ class BitmexWebsocketApi(WebsocketClient):
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg) self.gateway.write_log(msg)
sys.stderr.write( sys.stderr.write(self.exception_detail(exception_type, exception_value, tb))
self.exception_detail(exception_type,
exception_value,
tb)
)
def authenticate(self): def authenticate(self):
""" """
@ -451,9 +417,7 @@ class BitmexWebsocketApi(WebsocketClient):
path = "/realtime" path = "/realtime"
msg = method + path + str(expires) msg = method + path + str(expires)
signature = hmac.new( signature = hmac.new(
self.secret, self.secret, msg.encode(), digestmod=hashlib.sha256
msg.encode(),
digestmod=hashlib.sha256
).hexdigest() ).hexdigest()
req = {"op": "authKey", "args": [self.key, expires, signature]} req = {"op": "authKey", "args": [self.key, expires, signature]}
@ -464,8 +428,7 @@ class BitmexWebsocketApi(WebsocketClient):
Subscribe to all private topics. Subscribe to all private topics.
""" """
req = { req = {
"op": "op": "subscribe",
"subscribe",
"args": [ "args": [
"instrument", "instrument",
"trade", "trade",
@ -473,8 +436,8 @@ class BitmexWebsocketApi(WebsocketClient):
"execution", "execution",
"order", "order",
"position", "position",
"margin" "margin",
] ],
} }
self.send_packet(req) self.send_packet(req)
@ -486,10 +449,7 @@ class BitmexWebsocketApi(WebsocketClient):
return return
tick.last_price = d["price"] tick.last_price = d["price"]
tick.datetime = datetime.strptime( tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
d["timestamp"],
"%Y-%m-%dT%H:%M:%S.%fZ"
)
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def on_depth(self, d): def on_depth(self, d):
@ -509,10 +469,7 @@ class BitmexWebsocketApi(WebsocketClient):
tick.__setattr__("ask_price_%s" % (n + 1), price) tick.__setattr__("ask_price_%s" % (n + 1), price)
tick.__setattr__("ask_volume_%s" % (n + 1), volume) tick.__setattr__("ask_volume_%s" % (n + 1), volume)
tick.datetime = datetime.strptime( tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
d["timestamp"],
"%Y-%m-%dT%H:%M:%S.%fZ"
)
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def on_trade(self, d): def on_trade(self, d):
@ -540,7 +497,7 @@ class BitmexWebsocketApi(WebsocketClient):
price=d["lastPx"], price=d["lastPx"],
volume=d["lastQty"], volume=d["lastQty"],
time=d["timestamp"][11:19], time=d["timestamp"][11:19],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
@ -568,7 +525,7 @@ class BitmexWebsocketApi(WebsocketClient):
price=d["price"], price=d["price"],
volume=d["orderQty"], volume=d["orderQty"],
time=d["timestamp"][11:19], time=d["timestamp"][11:19],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.orders[sysid] = order self.orders[sysid] = order
@ -584,7 +541,7 @@ class BitmexWebsocketApi(WebsocketClient):
exchange=Exchange.BITMEX, exchange=Exchange.BITMEX,
direction=Direction.NET, direction=Direction.NET,
volume=d["currentQty"], volume=d["currentQty"],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.gateway.on_position(position) self.gateway.on_position(position)
@ -594,10 +551,7 @@ class BitmexWebsocketApi(WebsocketClient):
accountid = str(d["account"]) accountid = str(d["account"])
account = self.accounts.get(accountid, None) account = self.accounts.get(accountid, None)
if not account: if not account:
account = AccountData( account = AccountData(accountid=accountid, gateway_name=self.gateway_name)
accountid=accountid,
gateway_name=self.gateway_name
)
self.accounts[accountid] = account self.accounts[accountid] = account
account.balance = d.get("marginBalance", account.balance) account.balance = d.get("marginBalance", account.balance)
@ -621,7 +575,7 @@ class BitmexWebsocketApi(WebsocketClient):
product=Product.FUTURES, product=Product.FUTURES,
pricetick=d["tickSize"], pricetick=d["tickSize"],
size=d["lotSize"], size=d["lotSize"],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.gateway.on_contract(contract) self.gateway.on_contract(contract)

View File

@ -1 +1 @@
from .futu_gateway import FutuGateway from .futu_gateway import FutuGateway

View File

@ -22,7 +22,7 @@ from futu import (
StockQuoteHandlerBase, StockQuoteHandlerBase,
OrderBookHandlerBase, OrderBookHandlerBase,
TradeOrderHandlerBase, TradeOrderHandlerBase,
TradeDealHandlerBase TradeDealHandlerBase,
) )
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway
@ -36,14 +36,14 @@ from vnpy.trader.object import (
AccountData, AccountData,
SubscribeRequest, SubscribeRequest,
OrderRequest, OrderRequest,
CancelRequest CancelRequest,
) )
from vnpy.trader.event import EVENT_TIMER from vnpy.trader.event import EVENT_TIMER
EXCHANGE_VT2FUTU = { EXCHANGE_VT2FUTU = {
Exchange.SMART: "US", Exchange.SMART: "US",
Exchange.SEHK: "HK", Exchange.SEHK: "HK",
Exchange.HKFE: "HK_FUTURE" Exchange.HKFE: "HK_FUTURE",
} }
EXCHANGE_FUTU2VT = {v: k for k, v in EXCHANGE_VT2FUTU.items()} EXCHANGE_FUTU2VT = {v: k for k, v in EXCHANGE_VT2FUTU.items()}
@ -52,7 +52,7 @@ PRODUCT_VT2FUTU = {
Product.INDEX: "IDX", Product.INDEX: "IDX",
Product.ETF: "ETF", Product.ETF: "ETF",
Product.WARRANT: "WARRANT", Product.WARRANT: "WARRANT",
Product.BOND: "BOND" Product.BOND: "BOND",
} }
DIRECTION_VT2FUTU = {Direction.LONG: TrdSide.BUY, Direction.SHORT: TrdSide.SELL} DIRECTION_VT2FUTU = {Direction.LONG: TrdSide.BUY, Direction.SHORT: TrdSide.SELL}
@ -79,10 +79,8 @@ class FutuGateway(BaseGateway):
"password": "", "password": "",
"host": "127.0.0.1", "host": "127.0.0.1",
"port": 11111, "port": 11111,
"market": ["HK", "market": ["HK", "US"],
"US"], "env": [TrdEnv.REAL, TrdEnv.SIMULATE],
"env": [TrdEnv.REAL,
TrdEnv.SIMULATE]
} }
def __init__(self, event_engine): def __init__(self, event_engine):
@ -126,7 +124,7 @@ class FutuGateway(BaseGateway):
""" """
Query all data necessary. Query all data necessary.
""" """
sleep(2.0) # Wait 2 seconds till connection completed. sleep(2.0) # Wait 2 seconds till connection completed.
self.query_contract() self.query_contract()
self.query_trade() self.query_trade()
@ -235,7 +233,7 @@ class FutuGateway(BaseGateway):
def send_order(self, req): def send_order(self, req):
"""""" """"""
side = DIRECTION_VT2FUTU[req.direction] side = DIRECTION_VT2FUTU[req.direction]
price_type = OrderType.NORMAL # Only limit order is supported. price_type = OrderType.NORMAL # Only limit order is supported.
# Set price adjustment mode to inside adjustment. # Set price adjustment mode to inside adjustment.
if req.direction is Direction.LONG: if req.direction is Direction.LONG:
@ -251,7 +249,7 @@ class FutuGateway(BaseGateway):
side, side,
price_type, price_type,
trd_env=self.env, trd_env=self.env,
adjust_limit=adjust_limit adjust_limit=adjust_limit,
) )
if code: if code:
@ -268,11 +266,8 @@ class FutuGateway(BaseGateway):
def cancel_order(self, req): def cancel_order(self, req):
"""""" """"""
code, data = self.trade_ctx.modify_order( code, data = self.trade_ctx.modify_order(
ModifyOrderOp.CANCEL, ModifyOrderOp.CANCEL, req.orderid, 0, 0, trd_env=self.env
req.orderid, )
0,
0,
trd_env=self.env)
if code: if code:
self.write_log(f"撤单失败:{data}") self.write_log(f"撤单失败:{data}")
@ -295,7 +290,7 @@ class FutuGateway(BaseGateway):
product=product, product=product,
size=1, size=1,
pricetick=0.001, pricetick=0.001,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.on_contract(contract) self.on_contract(contract)
self.contracts[contract.vt_symbol] = contract self.contracts[contract.vt_symbol] = contract
@ -314,11 +309,8 @@ class FutuGateway(BaseGateway):
account = AccountData( account = AccountData(
accountid=f"{self.gateway_name}_{self.market}", accountid=f"{self.gateway_name}_{self.market}",
balance=float(row["total_assets"]), balance=float(row["total_assets"]),
frozen=( frozen=(float(row["total_assets"]) - float(row["avl_withdrawal_cash"])),
float(row["total_assets"]) - gateway_name=self.gateway_name,
float(row["avl_withdrawal_cash"])
),
gateway_name=self.gateway_name
) )
self.on_account(account) self.on_account(account)
@ -340,7 +332,7 @@ class FutuGateway(BaseGateway):
frozen=(float(row["qty"]) - float(row["can_sell_qty"])), frozen=(float(row["qty"]) - float(row["can_sell_qty"])),
price=float(row["pl_val"]), price=float(row["pl_val"]),
pnl=float(row["cost_price"]), pnl=float(row["cost_price"]),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.on_position(pos) self.on_position(pos)
@ -386,7 +378,7 @@ class FutuGateway(BaseGateway):
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.ticks[code] = tick self.ticks[code] = tick
@ -405,10 +397,7 @@ class FutuGateway(BaseGateway):
date = row["data_date"].replace("-", "") date = row["data_date"].replace("-", "")
time = row["data_time"] time = row["data_time"]
tick.datetime = datetime.strptime( tick.datetime = datetime.strptime(f"{date} {time}", "%Y%m%d %H:%M:%S")
f"{date} {time}",
"%Y%m%d %H:%M:%S"
)
tick.open_price = row["open_price"] tick.open_price = row["open_price"]
tick.high_price = row["high_price"] tick.high_price = row["high_price"]
tick.low_price = row["low_price"] tick.low_price = row["low_price"]
@ -462,7 +451,7 @@ class FutuGateway(BaseGateway):
traded=float(row["dealt_qty"]), traded=float(row["dealt_qty"]),
status=STATUS_FUTU2VT[row["order_status"]], status=STATUS_FUTU2VT[row["order_status"]],
time=row["create_time"].split(" ")[-1], time=row["create_time"].split(" ")[-1],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.on_order(order) self.on_order(order)
@ -487,7 +476,7 @@ class FutuGateway(BaseGateway):
price=float(row["price"]), price=float(row["price"]),
volume=float(row["qty"]), volume=float(row["qty"]),
time=row["create_time"].split(" ")[-1], time=row["create_time"].split(" ")[-1],
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.on_trade(trade) self.on_trade(trade)

View File

@ -1 +1 @@
from .ib_gateway import IbGateway from .ib_gateway import IbGateway

View File

@ -27,7 +27,7 @@ from vnpy.trader.object import (
AccountData, AccountData,
SubscribeRequest, SubscribeRequest,
OrderRequest, OrderRequest,
CancelRequest CancelRequest,
) )
from vnpy.trader.constant import ( from vnpy.trader.constant import (
Product, Product,
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
Exchange, Exchange,
Currency, Currency,
Status, Status,
OptionType OptionType,
) )
PRICETYPE_VT2IB = {PriceType.LIMIT: "LMT", PriceType.MARKET: "MKT"} PRICETYPE_VT2IB = {PriceType.LIMIT: "LMT", PriceType.MARKET: "MKT"}
@ -55,7 +55,7 @@ EXCHANGE_VT2IB = {
Exchange.CME: "CME", Exchange.CME: "CME",
Exchange.ICE: "ICE", Exchange.ICE: "ICE",
Exchange.SEHK: "SEHK", Exchange.SEHK: "SEHK",
Exchange.HKFE: "HKFE" Exchange.HKFE: "HKFE",
} }
EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()} EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()}
@ -64,7 +64,7 @@ STATUS_IB2VT = {
"Filled": Status.ALLTRADED, "Filled": Status.ALLTRADED,
"Cancelled": Status.CANCELLED, "Cancelled": Status.CANCELLED,
"PendingSubmit": Status.SUBMITTING, "PendingSubmit": Status.SUBMITTING,
"PreSubmitted": Status.NOTTRADED "PreSubmitted": Status.NOTTRADED,
} }
PRODUCT_VT2IB = { PRODUCT_VT2IB = {
@ -72,7 +72,7 @@ PRODUCT_VT2IB = {
Product.FOREX: "CASH", Product.FOREX: "CASH",
Product.SPOT: "CMDTY", Product.SPOT: "CMDTY",
Product.OPTION: "OPT", Product.OPTION: "OPT",
Product.FUTURES: "FUT" Product.FUTURES: "FUT",
} }
PRODUCT_IB2VT = {v: k for k, v in PRODUCT_VT2IB.items()} PRODUCT_IB2VT = {v: k for k, v in PRODUCT_VT2IB.items()}
@ -91,7 +91,7 @@ TICKFIELD_IB2VT = {
7: "low_price", 7: "low_price",
8: "volume", 8: "volume",
9: "pre_close", 9: "pre_close",
14: "open_price" 14: "open_price",
} }
ACCOUNTFIELD_IB2VT = { ACCOUNTFIELD_IB2VT = {
@ -182,21 +182,21 @@ class IbApi(EWrapper):
self.client = IbClient(self) self.client = IbClient(self)
self.thread = Thread(target=self.client.run) self.thread = Thread(target=self.client.run)
def connectAck(self): # pylint: disable=invalid-name def connectAck(self): # pylint: disable=invalid-name
""" """
Callback when connection is established. Callback when connection is established.
""" """
self.status = True self.status = True
self.gateway.write_log("IB TWS连接成功") self.gateway.write_log("IB TWS连接成功")
def connectionClosed(self): # pylint: disable=invalid-name def connectionClosed(self): # pylint: disable=invalid-name
""" """
Callback when connection is closed. Callback when connection is closed.
""" """
self.status = False self.status = False
self.gateway.write_log("IB TWS连接断开") self.gateway.write_log("IB TWS连接断开")
def nextValidId(self, orderId: int): # pylint: disable=invalid-name def nextValidId(self, orderId: int): # pylint: disable=invalid-name
""" """
Callback of next valid orderid. Callback of next valid orderid.
""" """
@ -204,7 +204,7 @@ class IbApi(EWrapper):
self.orderid = orderId self.orderid = orderId
def currentTime(self, time: int): # pylint: disable=invalid-name def currentTime(self, time: int): # pylint: disable=invalid-name
""" """
Callback of current server time of IB. Callback of current server time of IB.
""" """
@ -216,7 +216,9 @@ class IbApi(EWrapper):
msg = f"服务器时间: {time_string}" msg = f"服务器时间: {time_string}"
self.gateway.write_log(msg) self.gateway.write_log(msg)
def error(self, reqId: TickerId, errorCode: int, errorString: str): # pylint: disable=invalid-name def error(
self, reqId: TickerId, errorCode: int, errorString: str
): # pylint: disable=invalid-name
""" """
Callback of error caused by specific request. Callback of error caused by specific request.
""" """
@ -226,11 +228,7 @@ class IbApi(EWrapper):
self.gateway.write_log(msg) self.gateway.write_log(msg)
def tickPrice( # pylint: disable=invalid-name def tickPrice( # pylint: disable=invalid-name
self, self, reqId: TickerId, tickType: TickType, price: float, attrib: TickAttrib
reqId: TickerId,
tickType: TickType,
price: float,
attrib: TickAttrib
): ):
""" """
Callback of tick price update. Callback of tick price update.
@ -257,7 +255,9 @@ class IbApi(EWrapper):
tick.datetime = datetime.now() tick.datetime = datetime.now()
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def tickSize(self, reqId: TickerId, tickType: TickType, size: int): # pylint: disable=invalid-name def tickSize(
self, reqId: TickerId, tickType: TickType, size: int
): # pylint: disable=invalid-name
""" """
Callback of tick volume update. Callback of tick volume update.
""" """
@ -272,7 +272,9 @@ class IbApi(EWrapper):
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def tickString(self, reqId: TickerId, tickType: TickType, value: str): # pylint: disable=invalid-name def tickString(
self, reqId: TickerId, tickType: TickType, value: str
): # pylint: disable=invalid-name
""" """
Callback of tick string update. Callback of tick string update.
""" """
@ -287,36 +289,35 @@ class IbApi(EWrapper):
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def orderStatus( # pylint: disable=invalid-name def orderStatus( # pylint: disable=invalid-name
self, self,
orderId: OrderId, orderId: OrderId,
status: str, status: str,
filled: float, filled: float,
remaining: float, remaining: float,
avgFillPrice: float, avgFillPrice: float,
permId: int, permId: int,
parentId: int, parentId: int,
lastFillPrice: float, lastFillPrice: float,
clientId: int, clientId: int,
whyHeld: str, whyHeld: str,
mktCapPrice: float mktCapPrice: float,
): ):
""" """
Callback of order status update. Callback of order status update.
""" """
super(IbApi, super(IbApi, self).orderStatus(
self).orderStatus( orderId,
orderId, status,
status, filled,
filled, remaining,
remaining, avgFillPrice,
avgFillPrice, permId,
permId, parentId,
parentId, lastFillPrice,
lastFillPrice, clientId,
clientId, whyHeld,
whyHeld, mktCapPrice,
mktCapPrice )
)
orderid = str(orderId) orderid = str(orderId)
order = self.orders.get(orderid, None) order = self.orders.get(orderid, None)
@ -326,11 +327,11 @@ class IbApi(EWrapper):
self.gateway.on_order(copy(order)) self.gateway.on_order(copy(order))
def openOrder( # pylint: disable=invalid-name def openOrder( # pylint: disable=invalid-name
self, self,
orderId: OrderId, orderId: OrderId,
ib_contract: Contract, ib_contract: Contract,
ib_order: Order, ib_order: Order,
orderState: OrderState orderState: OrderState,
): ):
""" """
Callback when opening new order. Callback when opening new order.
@ -340,26 +341,19 @@ class IbApi(EWrapper):
orderid = str(orderId) orderid = str(orderId)
order = OrderData( order = OrderData(
symbol=ib_contract.conId, symbol=ib_contract.conId,
exchange=EXCHANGE_IB2VT.get( exchange=EXCHANGE_IB2VT.get(ib_contract.exchange, ib_contract.exchange),
ib_contract.exchange,
ib_contract.exchange
),
orderid=orderid, orderid=orderid,
direction=DIRECTION_IB2VT[ib_order.action], direction=DIRECTION_IB2VT[ib_order.action],
price=ib_order.lmtPrice, price=ib_order.lmtPrice,
volume=ib_order.totalQuantity, volume=ib_order.totalQuantity,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.orders[orderid] = order self.orders[orderid] = order
self.gateway.on_order(copy(order)) self.gateway.on_order(copy(order))
def updateAccountValue( # pylint: disable=invalid-name def updateAccountValue( # pylint: disable=invalid-name
self, self, key: str, val: str, currency: str, accountName: str
key: str,
val: str,
currency: str,
accountName: str
): ):
""" """
Callback of account update. Callback of account update.
@ -372,54 +366,49 @@ class IbApi(EWrapper):
accountid = f"{accountName}.{currency}" accountid = f"{accountName}.{currency}"
account = self.accounts.get(accountid, None) account = self.accounts.get(accountid, None)
if not account: if not account:
account = AccountData( account = AccountData(accountid=accountid, gateway_name=self.gateway_name)
accountid=accountid,
gateway_name=self.gateway_name
)
self.accounts[accountid] = account self.accounts[accountid] = account
name = ACCOUNTFIELD_IB2VT[key] name = ACCOUNTFIELD_IB2VT[key]
setattr(account, name, float(val)) setattr(account, name, float(val))
def updatePortfolio( # pylint: disable=invalid-name def updatePortfolio( # pylint: disable=invalid-name
self, self,
contract: Contract, contract: Contract,
position: float, position: float,
marketPrice: float, marketPrice: float,
marketValue: float, marketValue: float,
averageCost: float, averageCost: float,
unrealizedPNL: float, unrealizedPNL: float,
realizedPNL: float, realizedPNL: float,
accountName: str accountName: str,
): ):
""" """
Callback of position update. Callback of position update.
""" """
super(IbApi, super(IbApi, self).updatePortfolio(
self).updatePortfolio( contract,
contract, position,
position, marketPrice,
marketPrice, marketValue,
marketValue, averageCost,
averageCost, unrealizedPNL,
unrealizedPNL, realizedPNL,
realizedPNL, accountName,
accountName )
)
pos = PositionData( pos = PositionData(
symbol=contract.conId, symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange, exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
contract.exchange),
direction=DIRECTION_NET, direction=DIRECTION_NET,
volume=position, volume=position,
price=averageCost, price=averageCost,
pnl=unrealizedPNL, pnl=unrealizedPNL,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.gateway.on_position(pos) self.gateway.on_position(pos)
def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name
""" """
Callback of account update time. Callback of account update time.
""" """
@ -427,7 +416,9 @@ class IbApi(EWrapper):
for account in self.accounts.values(): for account in self.accounts.values():
self.gateway.on_account(copy(account)) self.gateway.on_account(copy(account))
def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name def contractDetails(
self, reqId: int, contractDetails: ContractDetails
): # pylint: disable=invalid-name
""" """
Callback of contract data update. Callback of contract data update.
""" """
@ -443,20 +434,21 @@ class IbApi(EWrapper):
contract = ContractData( contract = ContractData(
symbol=ib_symbol, symbol=ib_symbol,
exchange=EXCHANGE_IB2VT.get(ib_exchange, exchange=EXCHANGE_IB2VT.get(ib_exchange, ib_exchange),
ib_exchange),
name=contractDetails.longName, name=contractDetails.longName,
product=PRODUCT_IB2VT[ib_product], product=PRODUCT_IB2VT[ib_product],
size=ib_size, size=ib_size,
pricetick=contractDetails.minTick, pricetick=contractDetails.minTick,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
self.contracts[contract.vt_symbol] = contract self.contracts[contract.vt_symbol] = contract
def execDetails(self, reqId: int, contract: Contract, execution: Execution): # pylint: disable=invalid-name def execDetails(
self, reqId: int, contract: Contract, execution: Execution
): # pylint: disable=invalid-name
""" """
Callback of trade data update. Callback of trade data update.
""" """
@ -465,21 +457,19 @@ class IbApi(EWrapper):
today_date = datetime.now().strftime("%Y%m%d") today_date = datetime.now().strftime("%Y%m%d")
trade = TradeData( trade = TradeData(
symbol=contract.conId, symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange, exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
contract.exchange),
orderid=str(execution.orderId), orderid=str(execution.orderId),
tradeid=str(execution.execId), tradeid=str(execution.execId),
direction=DIRECTION_IB2VT[execution.side], direction=DIRECTION_IB2VT[execution.side],
price=execution.price, price=execution.price,
volume=execution.shares, volume=execution.shares,
time=datetime.strptime(execution.time, time=datetime.strptime(execution.time, "%Y%m%d %H:%M:%S"),
"%Y%m%d %H:%M:%S"), gateway_name=self.gateway_name,
gateway_name=self.gateway_name
) )
self.gateway.on_trade(trade) self.gateway.on_trade(trade)
def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name
""" """
Callback of all sub accountid. Callback of all sub accountid.
""" """
@ -497,11 +487,7 @@ class IbApi(EWrapper):
self.clientid = setting["clientid"] self.clientid = setting["clientid"]
self.client.connect( self.client.connect(setting["host"], setting["port"], setting["clientid"])
setting["host"],
setting["port"],
setting["clientid"]
)
self.thread.start() self.thread.start()
@ -544,7 +530,7 @@ class IbApi(EWrapper):
symbol=req.symbol, symbol=req.symbol,
exchange=req.exchange, exchange=req.exchange,
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
self.ticks[self.reqid] = tick self.ticks[self.reqid] = tick
self.tick_exchange[self.reqid] = req.exchange self.tick_exchange[self.reqid] = req.exchange

View File

@ -7,10 +7,11 @@ class BaseApp(ABC):
""" """
Absstract class for app. Absstract class for app.
""" """
app_name = "" # Unique name used for creating engine and widget
app_module = "" # App module string used in import_module app_name = "" # Unique name used for creating engine and widget
app_path = "" # Absolute path of app folder app_module = "" # App module string used in import_module
display_name = "" # Name for display on the menu. app_path = "" # Absolute path of app folder
engine_class = None # App engine class display_name = "" # Name for display on the menu.
widget_name = "" # Class name of app widget engine_class = None # App engine class
icon_name = "" # Icon file name of app widget widget_name = "" # Class name of app widget
icon_name = "" # Icon file name of app widget

View File

@ -9,6 +9,7 @@ class Direction(Enum):
""" """
Direction of order/trade/position. Direction of order/trade/position.
""" """
LONG = "" LONG = ""
SHORT = "" SHORT = ""
NET = "" NET = ""
@ -18,6 +19,7 @@ class Offset(Enum):
""" """
Offset of order/trade. Offset of order/trade.
""" """
NONE = "" NONE = ""
OPEN = "" OPEN = ""
CLOSE = "" CLOSE = ""
@ -29,6 +31,7 @@ class Status(Enum):
""" """
Order status. Order status.
""" """
SUBMITTING = "提交中" SUBMITTING = "提交中"
NOTTRADED = "未成交" NOTTRADED = "未成交"
PARTTRADED = "部分成交" PARTTRADED = "部分成交"
@ -41,6 +44,7 @@ class Product(Enum):
""" """
Product class. Product class.
""" """
EQUITY = "股票" EQUITY = "股票"
FUTURES = "期货" FUTURES = "期货"
OPTION = "期权" OPTION = "期权"
@ -56,6 +60,7 @@ class PriceType(Enum):
""" """
Order price type. Order price type.
""" """
LIMIT = "限价" LIMIT = "限价"
MARKET = "市价" MARKET = "市价"
FAK = "FAK" FAK = "FAK"
@ -66,6 +71,7 @@ class OptionType(Enum):
""" """
Option type. Option type.
""" """
CALL = "看涨期权" CALL = "看涨期权"
PUT = "看跌期权" PUT = "看跌期权"
@ -74,6 +80,7 @@ class Exchange(Enum):
""" """
Exchange. Exchange.
""" """
# Chinese # Chinese
CFFEX = "CFFEX" CFFEX = "CFFEX"
SHFE = "SHFE" SHFE = "SHFE"
@ -102,6 +109,7 @@ class Currency(Enum):
""" """
Currency. Currency.
""" """
USD = "USD" USD = "USD"
HKD = "HKD" HKD = "HKD"
CNY = "CNY" CNY = "CNY"
@ -111,4 +119,4 @@ class Interval(Enum):
MINUTE = "1m" MINUTE = "1m"
HOUR = "1h" HOUR = "1h"
DAILY = "d" DAILY = "d"
WEEKLY = "w" WEEKLY = "w"

View File

@ -6,7 +6,7 @@ from peewee import (
CharField, CharField,
DateTimeField, DateTimeField,
FloatField, FloatField,
IntegerField IntegerField,
) )
from .utility import get_temp_path from .utility import get_temp_path
@ -23,6 +23,7 @@ class DbBarData(Model):
Index is defined unique with vt_symbol, interval and datetime. Index is defined unique with vt_symbol, interval and datetime.
""" """
symbol = CharField() symbol = CharField()
exchange = CharField() exchange = CharField()
datetime = DateTimeField() datetime = DateTimeField()
@ -39,7 +40,7 @@ class DbBarData(Model):
class Meta: class Meta:
database = DB database = DB
indexes = ((('vt_symbol', 'interval', 'datetime'), True),) indexes = ((("vt_symbol", "interval", "datetime"), True),)
@staticmethod @staticmethod
def from_bar(bar: BarData): def from_bar(bar: BarData):
@ -76,7 +77,7 @@ class DbBarData(Model):
high_price=high_price, high_price=high_price,
low_price=low_price, low_price=low_price,
close_price=close_price, close_price=close_price,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
return bar return bar
@ -87,6 +88,7 @@ class DbTickData(Model):
Index is defined unique with vt_symbol, interval and datetime. Index is defined unique with vt_symbol, interval and datetime.
""" """
symbol = CharField() symbol = CharField()
exchange = CharField() exchange = CharField()
datetime = DateTimeField() datetime = DateTimeField()
@ -132,7 +134,7 @@ class DbTickData(Model):
class Meta: class Meta:
database = DB database = DB
indexes = ((('vt_symbol', 'datetime'), True),) indexes = ((("vt_symbol", "datetime"), True),)
@staticmethod @staticmethod
def from_tick(tick: TickData): def from_tick(tick: TickData):
@ -208,7 +210,7 @@ class DbTickData(Model):
ask_price_1=self.ask_price_1, ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1, bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1, ask_volume_1=self.ask_volume_1,
gateway_name=self.gateway_name gateway_name=self.gateway_name,
) )
if self.bid_price_2: if self.bid_price_2:

View File

@ -19,7 +19,7 @@ from .event import (
EVENT_TRADE, EVENT_TRADE,
EVENT_POSITION, EVENT_POSITION,
EVENT_ACCOUNT, EVENT_ACCOUNT,
EVENT_CONTRACT EVENT_CONTRACT,
) )
from .object import LogData, SubscribeRequest, OrderRequest, CancelRequest from .object import LogData, SubscribeRequest, OrderRequest, CancelRequest
from .utility import Singleton, get_temp_path from .utility import Singleton, get_temp_path
@ -180,10 +180,7 @@ class BaseEngine(ABC):
""" """
def __init__( def __init__(
self, self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str
main_engine: MainEngine,
event_engine: EventEngine,
engine_name: str
): ):
"""""" """"""
self.main_engine = main_engine self.main_engine = main_engine
@ -211,9 +208,7 @@ class LogEngine(BaseEngine):
self.level = SETTINGS["log.level"] self.level = SETTINGS["log.level"]
self.logger = logging.getLogger("VN Trader") self.logger = logging.getLogger("VN Trader")
self.formatter = logging.Formatter( self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
"%(asctime)s %(levelname)s: %(message)s"
)
self.add_null_handler() self.add_null_handler()
@ -249,7 +244,7 @@ class LogEngine(BaseEngine):
filename = f"vt_{today_date}.log" filename = f"vt_{today_date}.log"
file_path = get_temp_path(filename) file_path = get_temp_path(filename)
file_handler = logging.FileHandler(file_path, mode='w', encoding='utf8') file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8")
file_handler.setLevel(self.level) file_handler.setLevel(self.level)
file_handler.setFormatter(self.formatter) file_handler.setFormatter(self.formatter)
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
@ -421,7 +416,7 @@ class OmsEngine(BaseEngine):
""" """
return list(self.contracts.values()) return list(self.contracts.values())
def get_all_active_orders(self, vt_symbol: str = ''): def get_all_active_orders(self, vt_symbol: str = ""):
""" """
Get all active orders by vt_symbol. Get all active orders by vt_symbol.
@ -431,7 +426,8 @@ class OmsEngine(BaseEngine):
return list(self.active_orders.values()) return list(self.active_orders.values())
else: else:
active_orders = [ active_orders = [
order for order in self.active_orders.values() order
for order in self.active_orders.values()
if order.vt_symbol == vt_symbol if order.vt_symbol == vt_symbol
] ]
return active_orders return active_orders
@ -476,12 +472,10 @@ class EmailEngine(BaseEngine):
try: try:
msg = self.queue.get(block=True, timeout=1) msg = self.queue.get(block=True, timeout=1)
with smtplib.SMTP_SSL(SETTINGS["email.server"], with smtplib.SMTP_SSL(
SETTINGS["email.port"]) as smtp: SETTINGS["email.server"], SETTINGS["email.port"]
smtp.login( ) as smtp:
SETTINGS["email.username"], smtp.login(SETTINGS["email.username"], SETTINGS["email.password"])
SETTINGS["email.password"]
)
smtp.send_message(msg) smtp.send_message(msg)
except Empty: except Empty:
pass pass

View File

@ -4,10 +4,10 @@ Event type string used in VN Trader.
from vnpy.event import EVENT_TIMER from vnpy.event import EVENT_TIMER
EVENT_TICK = 'eTick.' EVENT_TICK = "eTick."
EVENT_TRADE = 'eTrade.' EVENT_TRADE = "eTrade."
EVENT_ORDER = 'eOrder.' EVENT_ORDER = "eOrder."
EVENT_POSITION = 'ePosition.' EVENT_POSITION = "ePosition."
EVENT_ACCOUNT = 'eAccount.' EVENT_ACCOUNT = "eAccount."
EVENT_CONTRACT = 'eContract.' EVENT_CONTRACT = "eContract."
EVENT_LOG = 'eLog' EVENT_LOG = "eLog"

View File

@ -14,7 +14,7 @@ from .event import (
EVENT_ACCOUNT, EVENT_ACCOUNT,
EVENT_POSITION, EVENT_POSITION,
EVENT_LOG, EVENT_LOG,
EVENT_CONTRACT EVENT_CONTRACT,
) )
from .object import ( from .object import (
TickData, TickData,
@ -26,7 +26,7 @@ from .object import (
ContractData, ContractData,
SubscribeRequest, SubscribeRequest,
OrderRequest, OrderRequest,
CancelRequest CancelRequest,
) )
@ -163,4 +163,4 @@ class BaseGateway(ABC):
""" """
Return default setting dict. Return default setting dict.
""" """
return self.default_setting return self.default_setting

View File

@ -17,6 +17,7 @@ class BaseData:
Any data object needs a gateway_name as source or Any data object needs a gateway_name as source or
destination and should inherit base data. destination and should inherit base data.
""" """
gateway_name: str gateway_name: str
@ -28,6 +29,7 @@ class TickData(BaseData):
* orderbook snapshot * orderbook snapshot
* intraday market statistics. * intraday market statistics.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
datetime: datetime datetime: datetime
@ -78,6 +80,7 @@ class BarData(BaseData):
""" """
Candlestick bar data of a certain trading period. Candlestick bar data of a certain trading period.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
datetime: datetime datetime: datetime
@ -100,6 +103,7 @@ class OrderData(BaseData):
Order data contains information for tracking lastest status Order data contains information for tracking lastest status
of a specific order. of a specific order.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
orderid: str orderid: str
@ -131,9 +135,7 @@ class OrderData(BaseData):
Create cancel request object from order. Create cancel request object from order.
""" """
req = CancelRequest( req = CancelRequest(
orderid=self.orderid, orderid=self.orderid, symbol=self.symbol, exchange=self.exchange
symbol=self.symbol,
exchange=self.exchange
) )
return req return req
@ -144,6 +146,7 @@ class TradeData(BaseData):
Trade data contains information of a fill of an order. One order Trade data contains information of a fill of an order. One order
can have several trade fills. can have several trade fills.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
orderid: str orderid: str
@ -167,6 +170,7 @@ class PositionData(BaseData):
""" """
Positon data is used for tracking each individual position holding. Positon data is used for tracking each individual position holding.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
direction: Direction direction: Direction
@ -188,6 +192,7 @@ class AccountData(BaseData):
Account data contains information about balance, frozen and Account data contains information about balance, frozen and
available. available.
""" """
accountid: str accountid: str
balance: float = 0 balance: float = 0
@ -204,6 +209,7 @@ class LogData(BaseData):
""" """
Log data is used for recording log messages on GUI or in log files. Log data is used for recording log messages on GUI or in log files.
""" """
msg: str msg: str
level: int = INFO level: int = INFO
@ -217,6 +223,7 @@ class ContractData(BaseData):
""" """
Contract data contains basic information about each contract traded. Contract data contains basic information about each contract traded.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
name: str name: str
@ -225,8 +232,8 @@ class ContractData(BaseData):
pricetick: float pricetick: float
option_strike: float = 0 option_strike: float = 0
option_underlying: str = '' # vt_symbol of underlying contract option_underlying: str = "" # vt_symbol of underlying contract
option_type: str = '' option_type: str = ""
option_expiry: datetime = None option_expiry: datetime = None
def __post_init__(self): def __post_init__(self):
@ -239,6 +246,7 @@ class SubscribeRequest:
""" """
Request sending to specific gateway for subscribing tick data update. Request sending to specific gateway for subscribing tick data update.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
@ -252,6 +260,7 @@ class OrderRequest:
""" """
Request sending to specific gateway for creating a new order. Request sending to specific gateway for creating a new order.
""" """
symbol: str symbol: str
exchange: Exchange exchange: Exchange
direction: Direction direction: Direction
@ -276,7 +285,7 @@ class OrderRequest:
offset=self.offset, offset=self.offset,
price=self.price, price=self.price,
volume=self.volume, volume=self.volume,
gateway_name=gateway_name gateway_name=gateway_name,
) )
return order return order
@ -286,6 +295,7 @@ class CancelRequest:
""" """
Request sending to specific gateway for canceling an existing order. Request sending to specific gateway for canceling an existing order.
""" """
orderid: str orderid: str
symbol: str symbol: str
exchange: Exchange exchange: Exchange

View File

@ -16,5 +16,5 @@ SETTINGS = {
"email.username": "", "email.username": "",
"email.password": "", "email.password": "",
"email.sender": "", "email.sender": "",
"email.receiver": "" "email.receiver": "",
} }

View File

@ -14,13 +14,8 @@ from ..utility import get_icon_path
def excepthook(exctype, value, tb): def excepthook(exctype, value, tb):
"""异常捕捉钩子""" """异常捕捉钩子"""
msg = ''.join(traceback.format_exception(exctype, value, tb)) msg = "".join(traceback.format_exception(exctype, value, tb))
QtWidgets.QMessageBox.critical( QtWidgets.QMessageBox.critical(None, u"Exception", msg, QtWidgets.QMessageBox.Ok)
None,
u'Exception',
msg,
QtWidgets.QMessageBox.Ok
)
def create_qapp(): def create_qapp():
@ -38,9 +33,7 @@ def create_qapp():
icon = QtGui.QIcon(get_icon_path(__file__, "vnpy.ico")) icon = QtGui.QIcon(get_icon_path(__file__, "vnpy.ico"))
qapp.setWindowIcon(icon) qapp.setWindowIcon(icon)
if 'Windows' in platform.uname(): if "Windows" in platform.uname():
ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID( ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("VN Trader")
'VN Trader'
)
return qapp return qapp

View File

@ -22,7 +22,7 @@ from .widget import (
TradingWidget, TradingWidget,
ActiveOrderMonitor, ActiveOrderMonitor,
ContractManager, ContractManager,
AboutDialog AboutDialog,
) )
@ -54,14 +54,30 @@ class MainWindow(QtWidgets.QMainWindow):
def init_dock(self): def init_dock(self):
"""""" """"""
trading_widget, trading_dock = self.create_dock(TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea) trading_widget, trading_dock = self.create_dock(
tick_widget, tick_dock = self.create_dock(TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea) TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea
order_widget, order_dock = self.create_dock(OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea) )
active_widget, active_dock = self.create_dock(ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea) tick_widget, tick_dock = self.create_dock(
trade_widget, trade_dock = self.create_dock(TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea) TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea
log_widget, log_dock = self.create_dock(LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea) )
account_widget, account_dock = self.create_dock(AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea) order_widget, order_dock = self.create_dock(
position_widget, position_dock = self.create_dock(PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea) OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea
)
active_widget, active_dock = self.create_dock(
ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea
)
trade_widget, trade_dock = self.create_dock(
TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea
)
log_widget, log_dock = self.create_dock(
LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea
)
account_widget, account_dock = self.create_dock(
AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea
)
position_widget, position_dock = self.create_dock(
PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea
)
self.tabifyDockWidget(active_dock, order_dock) self.tabifyDockWidget(active_dock, order_dock)
@ -96,52 +112,31 @@ class MainWindow(QtWidgets.QMainWindow):
func = partial(self.open_widget, widget_class, app.app_name) func = partial(self.open_widget, widget_class, app.app_name)
icon_path = str(app.app_path.joinpath("ui", app.icon_name)) icon_path = str(app.app_path.joinpath("ui", app.icon_name))
self.add_menu_action( self.add_menu_action(app_menu, f"打开{app.display_name}", icon_path, func)
app_menu,
f"打开{app.display_name}",
icon_path,
func
)
# Help menu # Help menu
self.add_menu_action( self.add_menu_action(
help_menu, help_menu,
"查询合约", "查询合约",
"contract.ico", "contract.ico",
partial(self.open_widget, partial(self.open_widget, ContractManager, "contract"),
ContractManager,
"contract")
) )
self.add_menu_action( self.add_menu_action(
help_menu, help_menu, "还原窗口", "restore.ico", self.restore_window_setting
"还原窗口",
"restore.ico",
self.restore_window_setting
) )
self.add_menu_action( self.add_menu_action(help_menu, "测试邮件", "email.ico", self.send_test_email)
help_menu,
"测试邮件",
"email.ico",
self.send_test_email
)
self.add_menu_action( self.add_menu_action(
help_menu, help_menu,
"关于", "关于",
"about.ico", "about.ico",
partial(self.open_widget, partial(self.open_widget, AboutDialog, "about"),
AboutDialog,
"about")
) )
def add_menu_action( def add_menu_action(
self, self, menu: QtWidgets.QMenu, action_name: str, icon_name: str, func: Callable
menu: QtWidgets.QMenu,
action_name: str,
icon_name: str,
func: Callable
): ):
"""""" """"""
icon = QtGui.QIcon(get_icon_path(__file__, icon_name)) icon = QtGui.QIcon(get_icon_path(__file__, icon_name))
@ -152,12 +147,7 @@ class MainWindow(QtWidgets.QMainWindow):
menu.addAction(action) menu.addAction(action)
def create_dock( def create_dock(self, widget_class: QtWidgets.QWidget, name: str, area: int):
self,
widget_class: QtWidgets.QWidget,
name: str,
area: int
):
""" """
Initialize a dock widget. Initialize a dock widget.
""" """
@ -189,7 +179,7 @@ class MainWindow(QtWidgets.QMainWindow):
"退出", "退出",
"确认退出?", "确认退出?",
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No, QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
QtWidgets.QMessageBox.No QtWidgets.QMessageBox.No,
) )
if reply == QtWidgets.QMessageBox.Yes: if reply == QtWidgets.QMessageBox.Yes:

View File

@ -18,7 +18,7 @@ from ..event import (
EVENT_ACCOUNT, EVENT_ACCOUNT,
EVENT_POSITION, EVENT_POSITION,
EVENT_CONTRACT, EVENT_CONTRACT,
EVENT_LOG EVENT_LOG,
) )
from ..object import SubscribeRequest, OrderRequest, CancelRequest from ..object import SubscribeRequest, OrderRequest, CancelRequest
from ..utility import load_setting, save_setting from ..utility import load_setting, save_setting
@ -299,22 +299,19 @@ class BaseMonitor(QtWidgets.QTableWidget):
""" """
Resize all columns according to contents. Resize all columns according to contents.
""" """
self.horizontalHeader().resizeSections( self.horizontalHeader().resizeSections(QtWidgets.QHeaderView.ResizeToContents)
QtWidgets.QHeaderView.ResizeToContents
)
def save_csv(self): def save_csv(self):
""" """
Save table data into a csv file Save table data into a csv file
""" """
path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "", path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "", "CSV(*.csv)")
"CSV(*.csv)")
if not path: if not path:
return return
with open(path, "w") as f: with open(path, "w") as f:
writer = csv.writer(f, lineterminator='\n') writer = csv.writer(f, lineterminator="\n")
writer.writerow(self.headers.keys()) writer.writerow(self.headers.keys())
@ -339,81 +336,26 @@ class TickMonitor(BaseMonitor):
""" """
Monitor for tick data. Monitor for tick data.
""" """
event_type = EVENT_TICK event_type = EVENT_TICK
data_key = "vt_symbol" data_key = "vt_symbol"
sorting = True sorting = True
headers = { headers = {
"symbol": { "symbol": {"display": "代码", "cell": BaseCell, "update": False},
"display": "代码", "exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"cell": BaseCell, "name": {"display": "名称", "cell": BaseCell, "update": True},
"update": False "last_price": {"display": "最新价", "cell": BaseCell, "update": True},
}, "volume": {"display": "成交量", "cell": BaseCell, "update": True},
"exchange": { "open_price": {"display": "开盘价", "cell": BaseCell, "update": True},
"display": "交易所", "high_price": {"display": "最高价", "cell": BaseCell, "update": True},
"cell": EnumCell, "low_price": {"display": "最低价", "cell": BaseCell, "update": True},
"update": False "bid_price_1": {"display": "买1价", "cell": BidCell, "update": True},
}, "bid_volume_1": {"display": "买1量", "cell": BidCell, "update": True},
"name": { "ask_price_1": {"display": "卖1价", "cell": AskCell, "update": True},
"display": "名称", "ask_volume_1": {"display": "卖1量", "cell": AskCell, "update": True},
"cell": BaseCell, "datetime": {"display": "时间", "cell": TimeCell, "update": True},
"update": True "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
},
"last_price": {
"display": "最新价",
"cell": BaseCell,
"update": True
},
"volume": {
"display": "成交量",
"cell": BaseCell,
"update": True
},
"open_price": {
"display": "开盘价",
"cell": BaseCell,
"update": True
},
"high_price": {
"display": "最高价",
"cell": BaseCell,
"update": True
},
"low_price": {
"display": "最低价",
"cell": BaseCell,
"update": True
},
"bid_price_1": {
"display": "买1价",
"cell": BidCell,
"update": True
},
"bid_volume_1": {
"display": "买1量",
"cell": BidCell,
"update": True
},
"ask_price_1": {
"display": "卖1价",
"cell": AskCell,
"update": True
},
"ask_volume_1": {
"display": "卖1量",
"cell": AskCell,
"update": True
},
"datetime": {
"display": "时间",
"cell": TimeCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
@ -421,26 +363,15 @@ class LogMonitor(BaseMonitor):
""" """
Monitor for log data. Monitor for log data.
""" """
event_type = EVENT_LOG event_type = EVENT_LOG
data_key = "" data_key = ""
sorting = False sorting = False
headers = { headers = {
"time": { "time": {"display": "时间", "cell": TimeCell, "update": False},
"display": "时间", "msg": {"display": "信息", "cell": MsgCell, "update": False},
"cell": TimeCell, "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
"update": False
},
"msg": {
"display": "信息",
"cell": MsgCell,
"update": False
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
@ -448,61 +379,22 @@ class TradeMonitor(BaseMonitor):
""" """
Monitor for trade data. Monitor for trade data.
""" """
event_type = EVENT_TRADE event_type = EVENT_TRADE
data_key = "" data_key = ""
sorting = True sorting = True
headers = { headers = {
"tradeid": { "tradeid": {"display": "成交号 ", "cell": BaseCell, "update": False},
"display": "成交号 ", "orderid": {"display": "委托号", "cell": BaseCell, "update": False},
"cell": BaseCell, "symbol": {"display": "代码", "cell": BaseCell, "update": False},
"update": False "exchange": {"display": "交易所", "cell": EnumCell, "update": False},
}, "direction": {"display": "方向", "cell": DirectionCell, "update": False},
"orderid": { "offset": {"display": "开平", "cell": EnumCell, "update": False},
"display": "委托号", "price": {"display": "价格", "cell": BaseCell, "update": False},
"cell": BaseCell, "volume": {"display": "数量", "cell": BaseCell, "update": False},
"update": False "time": {"display": "时间", "cell": BaseCell, "update": False},
}, "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
"symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"exchange": {
"display": "交易所",
"cell": EnumCell,
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"offset": {
"display": "开平",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": False
},
"time": {
"display": "时间",
"cell": BaseCell,
"update": False
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
@ -510,66 +402,23 @@ class OrderMonitor(BaseMonitor):
""" """
Monitor for order data. Monitor for order data.
""" """
event_type = EVENT_ORDER event_type = EVENT_ORDER
data_key = "vt_orderid" data_key = "vt_orderid"
sorting = True sorting = True
headers = { headers = {
"orderid": { "orderid": {"display": "委托号", "cell": BaseCell, "update": False},
"display": "委托号", "symbol": {"display": "代码", "cell": BaseCell, "update": False},
"cell": BaseCell, "exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"update": False "direction": {"display": "方向", "cell": DirectionCell, "update": False},
}, "offset": {"display": "开平", "cell": EnumCell, "update": False},
"symbol": { "price": {"display": "价格", "cell": BaseCell, "update": False},
"display": "代码", "volume": {"display": "总数量", "cell": BaseCell, "update": True},
"cell": BaseCell, "traded": {"display": "已成交", "cell": BaseCell, "update": True},
"update": False "status": {"display": "状态", "cell": EnumCell, "update": True},
}, "time": {"display": "时间", "cell": BaseCell, "update": True},
"exchange": { "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
"display": "交易所",
"cell": EnumCell,
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"offset": {
"display": "开平",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "总数量",
"cell": BaseCell,
"update": True
},
"traded": {
"display": "已成交",
"cell": BaseCell,
"update": True
},
"status": {
"display": "状态",
"cell": EnumCell,
"update": True
},
"time": {
"display": "时间",
"cell": BaseCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
def init_ui(self): def init_ui(self):
@ -594,51 +443,20 @@ class PositionMonitor(BaseMonitor):
""" """
Monitor for position data. Monitor for position data.
""" """
event_type = EVENT_POSITION event_type = EVENT_POSITION
data_key = "vt_positionid" data_key = "vt_positionid"
sorting = True sorting = True
headers = { headers = {
"symbol": { "symbol": {"display": "代码", "cell": BaseCell, "update": False},
"display": "代码", "exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"cell": BaseCell, "direction": {"display": "方向", "cell": DirectionCell, "update": False},
"update": False "volume": {"display": "数量", "cell": BaseCell, "update": True},
}, "frozen": {"display": "冻结", "cell": BaseCell, "update": True},
"exchange": { "price": {"display": "均价", "cell": BaseCell, "update": False},
"display": "交易所", "pnl": {"display": "盈亏", "cell": PnlCell, "update": True},
"cell": EnumCell, "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": True
},
"frozen": {
"display": "冻结",
"cell": BaseCell,
"update": True
},
"price": {
"display": "均价",
"cell": BaseCell,
"update": False
},
"pnl": {
"display": "盈亏",
"cell": PnlCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
@ -646,36 +464,17 @@ class AccountMonitor(BaseMonitor):
""" """
Monitor for account data. Monitor for account data.
""" """
event_type = EVENT_ACCOUNT event_type = EVENT_ACCOUNT
data_key = "vt_accountid" data_key = "vt_accountid"
sorting = True sorting = True
headers = { headers = {
"accountid": { "accountid": {"display": "账号", "cell": BaseCell, "update": False},
"display": "账号", "balance": {"display": "余额", "cell": BaseCell, "update": True},
"cell": BaseCell, "frozen": {"display": "冻结", "cell": BaseCell, "update": True},
"update": False "available": {"display": "可用", "cell": BaseCell, "update": True},
}, "gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
"balance": {
"display": "余额",
"cell": BaseCell,
"update": True
},
"frozen": {
"display": "冻结",
"cell": BaseCell,
"update": True
},
"available": {
"display": "可用",
"cell": BaseCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
} }
@ -701,9 +500,7 @@ class ConnectDialog(QtWidgets.QDialog):
self.setWindowTitle(f"连接{self.gateway_name}") self.setWindowTitle(f"连接{self.gateway_name}")
# Default setting provides field name, field data type and field default value. # Default setting provides field name, field data type and field default value.
default_setting = self.main_engine.get_default_setting( default_setting = self.main_engine.get_default_setting(self.gateway_name)
self.gateway_name
)
# Saved setting provides field data used last time. # Saved setting provides field data used last time.
loaded_setting = load_setting(self.filename) loaded_setting = load_setting(self.filename)
@ -732,7 +529,7 @@ class ConnectDialog(QtWidgets.QDialog):
form.addRow(f"{field_name} <{field_type.__name__}>", widget) form.addRow(f"{field_name} <{field_type.__name__}>", widget)
self.widgets[field_name] = (widget, field_type) self.widgets[field_name] = (widget, field_type)
button = QtWidgets.QPushButton(u"连接") button = QtWidgets.QPushButton("连接")
button.clicked.connect(self.connect) button.clicked.connect(self.connect)
form.addRow(button) form.addRow(button)
@ -792,18 +589,13 @@ class TradingWidget(QtWidgets.QWidget):
self.name_line.setReadOnly(True) self.name_line.setReadOnly(True)
self.direction_combo = QtWidgets.QComboBox() self.direction_combo = QtWidgets.QComboBox()
self.direction_combo.addItems( self.direction_combo.addItems([Direction.LONG.value, Direction.SHORT.value])
[Direction.LONG.value,
Direction.SHORT.value]
)
self.offset_combo = QtWidgets.QComboBox() self.offset_combo = QtWidgets.QComboBox()
self.offset_combo.addItems([offset.value for offset in Offset]) self.offset_combo.addItems([offset.value for offset in Offset])
self.price_type_combo = QtWidgets.QComboBox() self.price_type_combo = QtWidgets.QComboBox()
self.price_type_combo.addItems( self.price_type_combo.addItems([price_type.value for price_type in PriceType])
[price_type.value for price_type in PriceType]
)
double_validator = QtGui.QDoubleValidator() double_validator = QtGui.QDoubleValidator()
double_validator.setBottom(0) double_validator.setBottom(0)
@ -846,26 +638,11 @@ class TradingWidget(QtWidgets.QWidget):
self.bp4_label = self.create_label(bid_color) self.bp4_label = self.create_label(bid_color)
self.bp5_label = self.create_label(bid_color) self.bp5_label = self.create_label(bid_color)
self.bv1_label = self.create_label( self.bv1_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
bid_color, self.bv2_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
alignment=QtCore.Qt.AlignRight self.bv3_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
) self.bv4_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.bv2_label = self.create_label( self.bv5_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv3_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv4_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv5_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.ap1_label = self.create_label(ask_color) self.ap1_label = self.create_label(ask_color)
self.ap2_label = self.create_label(ask_color) self.ap2_label = self.create_label(ask_color)
@ -873,26 +650,11 @@ class TradingWidget(QtWidgets.QWidget):
self.ap4_label = self.create_label(ask_color) self.ap4_label = self.create_label(ask_color)
self.ap5_label = self.create_label(ask_color) self.ap5_label = self.create_label(ask_color)
self.av1_label = self.create_label( self.av1_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
ask_color, self.av2_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
alignment=QtCore.Qt.AlignRight self.av3_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
) self.av4_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.av2_label = self.create_label( self.av5_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av3_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av4_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av5_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.lp_label = self.create_label() self.lp_label = self.create_label()
self.return_label = self.create_label(alignment=QtCore.Qt.AlignRight) self.return_label = self.create_label(alignment=QtCore.Qt.AlignRight)
@ -916,11 +678,7 @@ class TradingWidget(QtWidgets.QWidget):
vbox.addLayout(form2) vbox.addLayout(form2)
self.setLayout(vbox) self.setLayout(vbox)
def create_label( def create_label(self, color: str = "", alignment: int = QtCore.Qt.AlignLeft):
self,
color: str = "",
alignment: int = QtCore.Qt.AlignLeft
):
""" """
Create label with certain font color. Create label with certain font color.
""" """
@ -992,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget):
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if not contract: if not contract:
self.name_line.setText("") self.name_line.setText("")
gateway_name = (self.gateway_combo.currentText()) gateway_name = self.gateway_combo.currentText()
else: else:
self.name_line.setText(contract.name) self.name_line.setText(contract.name)
gateway_name = contract.gateway_name gateway_name = contract.gateway_name
@ -1067,7 +825,7 @@ class TradingWidget(QtWidgets.QWidget):
price_type=PriceType(str(self.price_type_combo.currentText())), price_type=PriceType(str(self.price_type_combo.currentText())),
volume=volume, volume=volume,
price=price, price=price,
offset=Offset(str(self.offset_combo.currentText())) offset=Offset(str(self.offset_combo.currentText())),
) )
gateway_name = str(self.gateway_combo.currentText()) gateway_name = str(self.gateway_combo.currentText())
@ -1118,7 +876,7 @@ class ContractManager(QtWidgets.QWidget):
"product": "合约分类", "product": "合约分类",
"size": "合约乘数", "size": "合约乘数",
"pricetick": "价格跳动", "pricetick": "价格跳动",
"gateway_name": "交易接口" "gateway_name": "交易接口",
} }
def __init__(self, main_engine, event_engine): def __init__(self, main_engine, event_engine):
@ -1171,8 +929,7 @@ class ContractManager(QtWidgets.QWidget):
all_contracts = self.main_engine.get_all_contracts() all_contracts = self.main_engine.get_all_contracts()
if flt: if flt:
contracts = [ contracts = [
contract for contract in all_contracts contract for contract in all_contracts if flt in contract.vt_symbol
if flt in contract.vt_symbol
] ]
else: else:
contracts = all_contracts contracts = all_contracts

View File

@ -14,14 +14,13 @@ class Singleton(type):
__metaclass__ = Singleton __metaclass__ = Singleton
""" """
_instances = {} _instances = {}
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
"""""" """"""
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
cls).__call__(*args,
**kwargs)
return cls._instances[cls] return cls._instances[cls]
@ -39,7 +38,7 @@ def get_temp_path(filename: str):
Get path for temp file with filename. Get path for temp file with filename.
""" """
trader_path = get_trader_path() trader_path = get_trader_path()
temp_path = trader_path.joinpath('.vntrader') temp_path = trader_path.joinpath(".vntrader")
if not temp_path.exists(): if not temp_path.exists():
temp_path.mkdir() temp_path.mkdir()