[Mod]change formatting tools to black

This commit is contained in:
vn.py 2019-01-26 17:24:38 +08:00
parent f9fd309098
commit 5d4e975ff4
32 changed files with 634 additions and 1077 deletions

View File

@ -3,3 +3,7 @@ qdarkstyle
futu-api
websocket-client
peewee
numpy
pandas
matplotlib
seaborn

View File

@ -1,4 +1,4 @@
import unittest
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -5,20 +5,21 @@ from yapf.yapflib.yapf_api import FormatFile
logger = logging.Logger(__file__)
if __name__ == '__main__':
if __name__ == "__main__":
has_changed = False
for root, dir, filenames in os.walk("vnpy"):
for filename in filenames:
basename, ext = os.path.splitext(filename)
if ext == '.py':
if ext == ".py":
path = os.path.join(root, filename)
reformatted_code, encoding, changed = FormatFile(filename=path,
style_config='.style.yapf',
print_diff=True,
verify=False,
in_place=False,
logger=None
)
reformatted_code, encoding, changed = FormatFile(
filename=path,
style_config=".style.yapf",
print_diff=True,
verify=False,
in_place=False,
logger=None,
)
if changed:
has_changed = True
logger.warning("File {} not formatted!".format(path))

View File

@ -12,10 +12,10 @@ from typing import Any, Callable, Optional
class RequestStatus(Enum):
ready = 0 # Request created
success = 1 # Request successful (status code 2xx)
ready = 0 # Request created
success = 1 # Request successful (status code 2xx)
failed = 2 # Request failed (status code not 2xx)
error = 3 # Exception raised
error = 3 # Exception raised
class Request(object):
@ -24,16 +24,16 @@ class Request(object):
"""
def __init__(
self,
method: str,
path: str,
params: dict,
data: dict,
headers: dict,
callback: Callable,
on_failed: Callable = None,
on_error: Callable = None,
extra: Any = None
self,
method: str,
path: str,
params: dict,
data: dict,
headers: dict,
callback: Callable,
on_failed: Callable = None,
on_error: Callable = None,
extra: Any = None,
):
""""""
self.method = method
@ -52,7 +52,7 @@ class Request(object):
def __str__(self):
if self.response is None:
status_code = 'terminated'
status_code = "terminated"
else:
status_code = self.response.status_code
@ -70,7 +70,7 @@ class Request(object):
self.headers,
self.params,
self.data,
'' if self.response is None else self.response.text
"" if self.response is None else self.response.text,
)
)
@ -88,11 +88,11 @@ class RestClient(object):
def __init__(self):
"""
"""
self.url_base = None # type: str
self.url_base = None # type: str
self._active = False
self._queue = Queue()
self._pool = None # type: Pool
self._pool = None # type: Pool
self.proxies = None
@ -135,16 +135,16 @@ class RestClient(object):
self._queue.join()
def add_request(
self,
method: str,
path: str,
callback: Callable,
params: dict = None,
data: dict = None,
headers: dict = None,
on_failed: Callable = None,
on_error: Callable = None,
extra: Any = None
self,
method: str,
path: str,
callback: Callable,
params: dict = None,
data: dict = None,
headers: dict = None,
on_failed: Callable = None,
on_error: Callable = None,
extra: Any = None,
):
"""
Add a new request.
@ -160,15 +160,7 @@ class RestClient(object):
:return: Request
"""
request = Request(
method,
path,
params,
data,
headers,
callback,
on_failed,
on_error,
extra
method, path, params, data, headers, callback, on_failed, on_error, extra
)
self._queue.put(request)
return request
@ -204,54 +196,34 @@ class RestClient(object):
sys.stderr.write(str(request))
def on_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Default on_error handler for Python exception.
"""
sys.stderr.write(
self.exception_detail(exception_type,
exception_value,
tb,
request)
self.exception_detail(exception_type, exception_value, tb, request)
)
sys.excepthook(exception_type, exception_value, tb)
def exception_detail(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request
self, exception_type: type, exception_value: Exception, tb, request: Request
):
text = "[{}]: Unhandled RestClient Error:{}\n".format(
datetime.now().isoformat(),
exception_type
datetime.now().isoformat(), exception_type
)
text += "request:{}\n".format(request)
text += "Exception trace: \n"
text += "".join(
traceback.format_exception(
exception_type,
exception_value,
tb,
)
)
text += "".join(traceback.format_exception(exception_type, exception_value, tb))
return text
def _process_request(
self,
request: Request,
session: requests.session
): # type: (Request, requests.Session)->None
self, request: Request, session: requests.session
): # type: (Request, requests.Session)->None
"""
Sending request to server and get result.
"""
# noinspection PyBroadException
# noinspection PyBroadException
try:
request = self.sign(request)
@ -263,12 +235,12 @@ class RestClient(object):
headers=request.headers,
params=request.params,
data=request.data,
proxies=self.proxies
proxies=self.proxies,
)
request.response = response
status_code = response.status_code
if status_code / 100 == 2: # 2xx都算成功尽管交易所都用200
if status_code / 100 == 2: # 2xx都算成功尽管交易所都用200
jsonBody = response.json()
request.callback(jsonBody, request)
request.status = RequestStatus.success

View File

@ -123,9 +123,9 @@ class WebsocketClient(object):
""""""
self._ws = self._create_connection(
self.host,
sslopt={'cert_reqs': ssl.CERT_NONE},
sslopt={"cert_reqs": ssl.CERT_NONE},
http_proxy_host=self.proxy_host,
http_proxy_port=self.proxy_port
http_proxy_port=self.proxy_port,
)
self.on_connected()
@ -166,7 +166,7 @@ class WebsocketClient(object):
try:
data = self.unpack_data(text)
except ValueError as e:
print('websocket unable to parse data: ' + text)
print("websocket unable to parse data: " + text)
raise e
self.on_packet(data)
@ -211,7 +211,7 @@ class WebsocketClient(object):
""""""
ws = self._get_ws()
if ws:
ws.send('ping', websocket.ABNF.OPCODE_PING)
ws.send("ping", websocket.ABNF.OPCODE_PING)
@staticmethod
def on_connected():
@ -238,36 +238,20 @@ class WebsocketClient(object):
"""
Callback when exception raised.
"""
sys.stderr.write(
self.exception_detail(exception_type,
exception_value,
tb)
)
sys.stderr.write(self.exception_detail(exception_type, exception_value, tb))
return sys.excepthook(exception_type, exception_value, tb)
def exception_detail(
self,
exception_type: type,
exception_value: Exception,
tb
):
def exception_detail(self, exception_type: type, exception_value: Exception, tb):
"""
Print detailed exception information.
"""
text = "[{}]: Unhandled WebSocket Error:{}\n".format(
datetime.now().isoformat(),
exception_type
datetime.now().isoformat(), exception_type
)
text += "LastSentText:\n{}\n".format(self._last_sent_text)
text += "LastReceivedText:\n{}\n".format(self._last_received_text)
text += "Exception trace: \n"
text += "".join(
traceback.format_exception(
exception_type,
exception_value,
tb,
)
)
text += "".join(traceback.format_exception(exception_type, exception_value, tb))
return text
def _record_last_sent_text(self, text: str):

View File

@ -8,6 +8,7 @@ from .base import APP_NAME
class CtaStrategyApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent

View File

@ -4,6 +4,7 @@ from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame
from vnpy.trader.constant import Interval, Status, Direction, Exchange
@ -18,13 +19,16 @@ from .base import (
EngineType,
StopOrder,
BacktestingMode,
ORDER_CTA2VT
ORDER_CTA2VT,
)
from .template import CtaTemplate
sns.set_style("whitegrid")
class BacktestingEngine:
""""""
engine_type = EngineType.BACKTESTING
gateway_name = "BACKTESTING"
@ -39,7 +43,7 @@ class BacktestingEngine:
self.slippage = 0
self.size = 1
self.pricetick = 0
self.capital = 1000000
self.capital = 1_000_000
self.mode = BacktestingMode.BAR
self.strategy = None
@ -92,17 +96,17 @@ class BacktestingEngine:
self.daily_results.clear()
def set_parameters(
self,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = None
self,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = None,
):
""""""
self.mode = mode
@ -124,10 +128,7 @@ class BacktestingEngine:
def add_strategy(self, strategy_class: type, setting: dict):
""""""
self.strategy = strategy_class(
self,
strategy_class.__name__,
self.vt_symbol,
setting
self, strategy_class.__name__, self.vt_symbol, setting
)
def load_data(self):
@ -135,18 +136,26 @@ class BacktestingEngine:
self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR:
s = DbBarData.select().where(
(DbBarData.vt_symbol == self.vt_symbol) &
(DbBarData.interval == self.interval) &
(DbBarData.datetime >= self.start) &
(DbBarData.datetime <= self.end)
).order_by(DbBarData.datetime)
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == self.vt_symbol)
& (DbBarData.interval == self.interval)
& (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
)
else:
s = DbTickData.select().where(
(DbTickData.vt_symbol == self.vt_symbol) &
(DbTickData.datetime >= self.start) &
(DbTickData.datetime <= self.end)
).order_by(DbTickData.datetime)
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == self.vt_symbol)
& (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
)
self.history_data = list(s)
@ -164,7 +173,7 @@ class BacktestingEngine:
# Use the first [days] of history data for initializing strategy
day_count = 0
for ix, data in enumerate(self.history_data):
if (self.datetime and data.datetime.day != self.datetime.day):
if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1
if day_count >= self.days:
break
@ -205,11 +214,7 @@ class BacktestingEngine:
for daily_result in self.daily_results.values():
daily_result.calculate_pnl(
pre_close,
start_pos,
self.size,
self.rate,
self.slippage
pre_close, start_pos, self.size, self.rate, self.slippage
)
pre_close = daily_result.close_price
@ -236,8 +241,12 @@ class BacktestingEngine:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
df["return"] = (np.log(df["balance}" - np.log(df["balance"].shift(1))).fillna(0)
df["highlevel"] = df["balance"].rolling(min_periods=1,window=len(df),center=False).max()
df["return"] = (np.log(df["balance"] - np.log(df["balance"].shift(1)))).fillna(
0
)
df["highlevel"] = (
df["balance"].rolling(min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
@ -246,8 +255,8 @@ class BacktestingEngine:
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["netPnl"]>0])
loss_days = len(df[df["netPnl"]<0])
profit_days = len(df[df["netPnl"] > 0])
loss_days = len(df[df["netPnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
@ -268,7 +277,7 @@ class BacktestingEngine:
total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days
total_return = (end_balance/self.capital - 1) * 100
total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * 240
daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100
@ -279,7 +288,7 @@ class BacktestingEngine:
sharpe_ratio = 0
# Output
# 输出统计结果
# 输出统计结果
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
@ -335,7 +344,7 @@ class BacktestingEngine:
"annual_return": annual_return,
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio
"sharpe_ratio": sharpe_ratio,
}
return statistics
@ -348,20 +357,20 @@ class BacktestingEngine:
fig = plt.figure(figsize=(10, 16))
balance_plot = plt.subplot(4, 1, 1)
balance_plot.set_title('Balance')
df['balance'].plot(legend=True)
balance_plot.set_title("Balance")
df["balance"].plot(legend=True)
drawdown_plot = plt.subplot(4, 1, 2)
drawdown_plot.set_title('Drawdown')
drawdown_plot.fill_between(range(len(df)), df['drawdown'].values)
drawdown_plot.set_title("Drawdown")
drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
pnl_plot = plt.subplot(4, 1, 3)
pnl_plot.set_title('Daily Pnl')
df['net_pnl'].plot(kind='bar', legend=False, grid=False, xticks=[])
pnl_plot.set_title("Daily Pnl")
df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
distribution_plot = plt.subplot(4, 1, 4)
distribution_plot.set_title('Daily Pnl Distribution')
df['net_pnl'].hist(bins=50)
distribution_plot.set_title("Daily Pnl Distribution")
df["net_pnl"].hist(bins=50)
plt.show()
@ -421,12 +430,14 @@ class BacktestingEngine:
# Check whether limit orders can be filled.
long_cross = (
order.direction == Direction.LONG
and order.price >= long_cross_price and long_cross_price > 0
and order.price >= long_cross_price
and long_cross_price > 0
)
short_cross = (
order.direction == Direction.SHORT
and order.price <= short_cross_price and short_cross_price > 0
and order.price <= short_cross_price
and short_cross_price > 0
)
if not long_cross and not short_cross:
@ -459,7 +470,7 @@ class BacktestingEngine:
price=trade_price,
volume=order.volume,
time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
trade.datetime = self.datetime
@ -510,7 +521,7 @@ class BacktestingEngine:
price=stop_order.price,
volume=stop_order.volume,
status=Status.ALLTRADED,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.limit_orders[order.vt_orderid] = order
@ -535,7 +546,7 @@ class BacktestingEngine:
price=trade_price,
volume=order.volume,
time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
trade.datetime = self.datetime
@ -553,11 +564,7 @@ class BacktestingEngine:
self.strategy.on_trade(trade)
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
):
""""""
self.days = days
@ -569,12 +576,12 @@ class BacktestingEngine:
self.callback = callback
def send_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool = False
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool = False,
):
""""""
if stop:
@ -582,12 +589,7 @@ class BacktestingEngine:
else:
return self.send_limit_order(order_type, price, volume)
def send_stop_order(
self,
order_type: CtaOrderType,
price: float,
volume: float
):
def send_stop_order(self, order_type: CtaOrderType, price: float, volume: float):
""""""
self.stop_order_count += 1
@ -597,7 +599,7 @@ class BacktestingEngine:
price=price,
volume=volume,
stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}",
strategy_name=self.strategy.strategy_name
strategy_name=self.strategy.strategy_name,
)
self.active_stop_orders[stop_order.stop_orderid] = stop_order
@ -605,12 +607,7 @@ class BacktestingEngine:
return stop_order.stop_orderid
def send_limit_order(
self,
order_type: CtaOrderType,
price: float,
volume: float
):
def send_limit_order(self, order_type: CtaOrderType, price: float, volume: float):
""""""
self.limit_order_count += 1
direction, offset = ORDER_CTA2VT[order_type]
@ -624,7 +621,7 @@ class BacktestingEngine:
price=price,
volume=volume,
status=Status.NOTTRADED,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.active_limit_orders[order.vt_orderid] = order
@ -724,19 +721,17 @@ class DailyResult:
self.trades.append(trade)
def calculate_pnl(
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float,
):
""""""
# Holding pnl is the pnl from holding position at day start
self.start_pos = self.end_pos = start_pos
self.holding_pnl = self.start_pos * (
self.close_price - self.pre_close
) * size
self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size
# Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades)
@ -749,9 +744,7 @@ class DailyResult:
pos_change -= trade.volume
turnover = trade.price * trade.volume * size
self.trading_pnl += pos_change * (
self.close_price - trade.price
) * size
self.trading_pnl += pos_change * (self.close_price - trade.price) * size
self.end_pos += pos_change
self.turnover += turnover
self.commission += turnover * rate
@ -760,3 +753,39 @@ class DailyResult:
# Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list

View File

@ -51,17 +51,13 @@ class StopOrder:
self.direction, self.offset = ORDER_CTA2VT[self.order_type]
EVENT_CTA_LOG = 'eCtaLog'
EVENT_CTA_STRATEGY = 'eCtaStrategy'
EVENT_CTA_STOPORDER = 'eCtaStopOrder'
EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = "eCtaStopOrder"
ORDER_CTA2VT = {
CtaOrderType.BUY: (Direction.LONG,
Offset.OPEN),
CtaOrderType.SELL: (Direction.SHORT,
Offset.CLOSE),
CtaOrderType.SHORT: (Direction.SHORT,
Offset.OPEN),
CtaOrderType.COVER: (Direction.LONG,
Offset.CLOSE),
CtaOrderType.BUY: (Direction.LONG, Offset.OPEN),
CtaOrderType.SELL: (Direction.SHORT, Offset.CLOSE),
CtaOrderType.SHORT: (Direction.SHORT, Offset.OPEN),
CtaOrderType.COVER: (Direction.LONG, Offset.CLOSE),
}

View File

@ -15,7 +15,7 @@ from vnpy.trader.object import (
CancelRequest,
SubscribeRequest,
LogData,
TickData
TickData,
)
from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE
from vnpy.trader.constant import Direction, Offset, Exchange, PriceType, Interval
@ -31,36 +31,32 @@ from .base import (
ORDER_CTA2VT,
EVENT_CTA_LOG,
EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER
EVENT_CTA_STOPORDER,
)
class CtaEngine(BaseEngine):
""""""
engine_type = EngineType.LIVE # live trading engine
engine_type = EngineType.LIVE # live trading engine
filename = "CtaStrategy.vt"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super(CtaEngine,
self).__init__(main_engine,
event_engine,
"CtaStrategy")
super(CtaEngine, self).__init__(main_engine, event_engine, "CtaStrategy")
self.setting_file = None # setting file object
self.setting_file = None # setting file object
self.classes = {} # class_name: stategy_class
self.strategies = {} # strategy_name: strategy
self.classes = {} # class_name: stategy_class
self.strategies = {} # strategy_name: strategy
self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list
self.orderid_strategy_map = {} # vt_orderid: strategy
self.strategy_orderid_map = defaultdict(
set
) # strategy_name: orderid list
self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list
self.orderid_strategy_map = {} # vt_orderid: strategy
self.strategy_orderid_map = defaultdict(set) # strategy_name: orderid list
self.stop_order_count = 0 # for generating stop_orderid
self.stop_orders = {} # stop_orderid: stop_order
self.stop_order_count = 0 # for generating stop_orderid
self.stop_orders = {} # stop_orderid: stop_order
def init_engine(self):
"""
@ -131,12 +127,10 @@ class CtaEngine(BaseEngine):
continue
long_triggered = (
so.direction == Direction.LONG
and tick.last_price >= stop_order.price
so.direction == Direction.LONG and tick.last_price >= stop_order.price
)
short_triggered = (
so.direction == Direction.SHORT
and tick.last_price <= stop_order.price
so.direction == Direction.SHORT and tick.last_price <= stop_order.price
)
if long_triggered or short_triggered:
@ -157,10 +151,7 @@ class CtaEngine(BaseEngine):
price = tick.bid_price_5
vt_orderid = self.send_limit_order(
strategy,
stop_order.order_type,
price,
stop_order.volume
strategy, stop_order.order_type, price, stop_order.volume
)
# Update stop order status if placed successfully
@ -177,17 +168,15 @@ class CtaEngine(BaseEngine):
stop_order.vt_orderid = vt_orderid
self.call_strategy_func(
strategy,
strategy.on_stop_order,
stop_order
strategy, strategy.on_stop_order, stop_order
)
def send_limit_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
):
"""
Send a new order.
@ -207,12 +196,9 @@ class CtaEngine(BaseEngine):
offset=offset,
price_type=PriceType.LIMIT,
price=price,
volume=volume
)
vt_orderid = self.main_engine.send_limit_order(
req,
contract.gateway_name
volume=volume,
)
vt_orderid = self.main_engine.send_limit_order(req, contract.gateway_name)
# Save relationship between orderid and strategy.
self.orderid_strategy_map[vt_orderid] = strategy
@ -223,11 +209,11 @@ class CtaEngine(BaseEngine):
return vt_orderid
def send_stop_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
):
"""
Send a new order.
@ -243,7 +229,7 @@ class CtaEngine(BaseEngine):
price=price,
volume=volume,
stop_orderid=stop_orderid,
strategy_name=strategy.strategy_name
strategy_name=strategy.strategy_name,
)
self.stop_orders[stop_orderid] = stop_order
@ -289,12 +275,12 @@ class CtaEngine(BaseEngine):
self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
def send_order(
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool
self,
strategy: CtaTemplate,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool,
):
"""
"""
@ -327,11 +313,7 @@ class CtaEngine(BaseEngine):
return self.engine_type
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
):
""""""
pass
@ -341,10 +323,7 @@ class CtaEngine(BaseEngine):
pass
def call_strategy_func(
self,
strategy: CtaTemplate,
func: Callable,
params: Any = None
self, strategy: CtaTemplate, func: Callable, params: Any = None
):
"""
Call function of a strategy and catch any exception raised.
@ -362,11 +341,7 @@ class CtaEngine(BaseEngine):
self.write_log(msg, strategy)
def add_strategy(
self,
class_name: str,
strategy_name: str,
vt_symbol: str,
setting: dict
self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict
):
"""
Add a new strategy.
@ -462,29 +437,18 @@ class CtaEngine(BaseEngine):
Load strategy class from source code.
"""
path1 = Path(__file__).parent.joinpath("strategies")
self.load_strategy_class_from_folder(
path1,
"vnpy.app.cta_strategy.strategies"
)
self.load_strategy_class_from_folder(path1, "vnpy.app.cta_strategy.strategies")
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies")
def load_strategy_class_from_folder(
self,
path: Path,
module_name: str = ""
):
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
module_name = ".".join(
[module_name,
filename.replace(".py",
"")]
)
module_name = ".".join([module_name, filename.replace(".py", "")])
self.load_strategy_class_from_module(module_name)
def load_strategy_class_from_module(self, module_name: str):
@ -566,7 +530,7 @@ class CtaEngine(BaseEngine):
strategy.__class__.__name__,
strategy_name,
strategy.vt_symbol,
setting
setting,
)
self.setting_file.sync()

View File

@ -16,8 +16,6 @@ class DoubleMaStrategy(CtaTemplate):
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
""""""
super(DoubleMaStrategy,
self).__init__(cta_engine,
strategy_name,
vt_symbol,
setting)
super(DoubleMaStrategy, self).__init__(
cta_engine, strategy_name, vt_symbol, setting
)

View File

@ -1,7 +1,7 @@
""""""
from abc import ABC
from typing import Any
from typing import Any, Callable
from vnpy.trader.object import TickData, OrderData, TradeData, BarData
from vnpy.trader.constant import Interval
@ -17,11 +17,7 @@ class CtaTemplate(ABC):
variables = []
def __init__(
self,
cta_engine: Any,
strategy_name: str,
vt_symbol: str,
setting: dict
self, cta_engine: Any, strategy_name: str, vt_symbol: str, setting: dict
):
""""""
self.cta_engine = cta_engine
@ -84,7 +80,7 @@ class CtaTemplate(ABC):
"class_name": self.__class__.__name__,
"author": self.author,
"parameters": self.get_parameters(),
"variables": self.get_variables()
"variables": self.get_variables(),
}
return strategy_data
@ -155,11 +151,7 @@ class CtaTemplate(ABC):
return self.send_order(CtaOrderType.COVER, price, volume, stop)
def send_order(
self,
order_type: CtaOrderType,
price: float,
volume: float,
stop: bool = False
self, order_type: CtaOrderType, price: float, volume: float, stop: bool = False
):
"""
Send a new order.
@ -191,14 +183,14 @@ class CtaTemplate(ABC):
return self.cta_engine.get_engine_type()
def load_bar(
self,
days: int,
interval: Interval = Interval.MINUTE,
callback=self.on_bar
self, days: int, interval: Interval = Interval.MINUTE, callback: Callable = None
):
"""
Load historical bar data for initializing strategy.
"""
if not callback:
callback = self.on_bar
self.cta_engine.load_bar(self.vt_symbol, days, interval, callback)
def load_tick(self, days: int):

View File

@ -11,6 +11,7 @@ from ..base import APP_NAME, EVENT_CTA_LOG, EVENT_CTA_STOPORDER, EVENT_CTA_STRAT
class CtaManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_strategy = QtCore.pyqtSignal(Event)
@ -61,10 +62,7 @@ class CtaManager(QtWidgets.QWidget):
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
self.stop_order_monitor = StopOrderMonitor(
self.main_engine,
self.event_engine
)
self.stop_order_monitor = StopOrderMonitor(self.main_engine, self.event_engine)
# Set layout
hbox1 = QtWidgets.QHBoxLayout()
@ -88,18 +86,13 @@ class CtaManager(QtWidgets.QWidget):
def update_class_combo(self):
""""""
self.class_combo.addItems(
self.cta_engine.get_all_strategy_class_names()
)
self.class_combo.addItems(self.cta_engine.get_all_strategy_class_names())
def register_event(self):
""""""
self.signal_strategy.connect(self.process_strategy_event)
self.event_engine.register(
EVENT_CTA_STRATEGY,
self.signal_strategy.emit
)
self.event_engine.register(EVENT_CTA_STRATEGY, self.signal_strategy.emit)
def process_strategy_event(self, event):
"""
@ -136,12 +129,7 @@ class CtaManager(QtWidgets.QWidget):
vt_symbol = setting.pop("vt_symbol")
strategy_name = setting.pop("strategy_name")
self.cta_engine.add_strategy(
class_name,
strategy_name,
vt_symbol,
setting
)
self.cta_engine.add_strategy(class_name, strategy_name, vt_symbol, setting)
def show(self):
""""""
@ -153,12 +141,7 @@ class StrategyManager(QtWidgets.QFrame):
Manager for a strategy
"""
def __init__(
self,
cta_manager: CtaManager,
cta_engine: CtaEngine,
data: dict
):
def __init__(self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict):
""""""
super(StrategyManager, self).__init__()
@ -277,9 +260,7 @@ class DataMonitor(QtWidgets.QTableWidget):
self.setHorizontalHeaderLabels(labels)
self.setRowCount(1)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
self.verticalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
@ -320,51 +301,20 @@ class StopOrderMonitor(BaseMonitor):
"""
Monitor for local stop order.
"""
event_type = EVENT_CTA_STOPORDER
data_key = "stop_orderid"
sorting = True
headers = {
"stop_orderid": {
"display": "停止委托号",
"cell": BaseCell,
"update": False
},
"vt_orderid": {
"display": "限价委托号",
"cell": BaseCell,
"update": True
},
"vt_symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"order_type": {
"display": "类型",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": True
},
"status": {
"display": "状态",
"cell": EnumCell,
"update": True
},
"strategy": {
"display": "策略名",
"cell": StrategyCell,
"update": True
}
"stop_orderid": {"display": "停止委托号", "cell": BaseCell, "update": False},
"vt_orderid": {"display": "限价委托号", "cell": BaseCell, "update": True},
"vt_symbol": {"display": "代码", "cell": BaseCell, "update": False},
"order_type": {"display": "类型", "cell": EnumCell, "update": False},
"price": {"display": "价格", "cell": BaseCell, "update": False},
"volume": {"display": "数量", "cell": BaseCell, "update": True},
"status": {"display": "状态", "cell": EnumCell, "update": True},
"strategy": {"display": "策略名", "cell": StrategyCell, "update": True},
}
def init_ui(self):
@ -373,30 +323,21 @@ class StopOrderMonitor(BaseMonitor):
"""
super(StopOrderMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
self.horizontalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
class LogMonitor(BaseMonitor):
"""
Monitor for log data.
"""
event_type = EVENT_CTA_LOG
data_key = ""
sorting = False
headers = {
"time": {
"display": "时间",
"cell": TimeCell,
"update": False
},
"msg": {
"display": "信息",
"cell": MsgCell,
"update": False
}
"time": {"display": "时间", "cell": TimeCell, "update": False},
"msg": {"display": "信息", "cell": MsgCell, "update": False},
}
def init_ui(self):
@ -405,10 +346,7 @@ class LogMonitor(BaseMonitor):
"""
super(LogMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode(
1,
QtWidgets.QHeaderView.Stretch
)
self.horizontalHeader().setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch)
def insert_new_row(self, data):
"""
@ -423,12 +361,7 @@ class SettingEditor(QtWidgets.QDialog):
For creating new strategy and editing strategy parameters.
"""
def __init__(
self,
parameters: dict,
strategy_name: str = "",
class_name: str = ""
):
def __init__(self, parameters: dict, strategy_name: str = "", class_name: str = ""):
""""""
super(SettingEditor, self).__init__()

View File

@ -31,7 +31,7 @@ from vnpy.trader.object import (
TradeData,
PositionData,
AccountData,
ContractData
ContractData,
)
from vnpy.trader.constant import Direction, Status, PriceType, Exchange, Product
@ -46,7 +46,7 @@ STATUS_BITMEX2VT = {
"Partially filled": Status.PARTTRADED,
"Filled": Status.ALLTRADED,
"Canceled": Status.CANCELLED,
"Rejected": Status.REJECTED
"Rejected": Status.REJECTED,
}
DIRECTION_VT2BITMEX = {Direction.LONG: "Buy", Direction.SHORT: "Sell"}
@ -64,10 +64,9 @@ class BitmexGateway(BaseGateway):
"key": "",
"secret": "",
"session": 3,
"server": ["REAL",
"TESTNET"],
"server": ["REAL", "TESTNET"],
"proxy_host": "127.0.0.1",
"proxy_port": 1080
"proxy_port": 1080,
}
def __init__(self, event_engine):
@ -86,14 +85,7 @@ class BitmexGateway(BaseGateway):
proxy_host = setting["proxy_host"]
proxy_port = setting["proxy_port"]
self.rest_api.connect(
key,
secret,
session,
server,
proxy_host,
proxy_port
)
self.rest_api.connect(key, secret, session, server, proxy_host, proxy_port)
self.ws_api.connect(key, secret, server, proxy_host, proxy_port)
@ -138,7 +130,7 @@ class BitmexRestApi(RestClient):
self.key = ""
self.secret = ""
self.order_count = 1000000
self.order_count = 1_000_000
self.connect_time = 0
def sign(self, request):
@ -161,9 +153,7 @@ class BitmexRestApi(RestClient):
msg = request.method + "/api/v1" + path + str(expires) + request.data
signature = hmac.new(
self.secret,
msg.encode(),
digestmod=hashlib.sha256
self.secret, msg.encode(), digestmod=hashlib.sha256
).hexdigest()
# Add headers
@ -172,20 +162,20 @@ class BitmexRestApi(RestClient):
"Accept": "application/json",
"api-key": self.key,
"api-expires": str(expires),
"api-signature": signature
"api-signature": signature,
}
request.headers = headers
return request
def connect(
self,
key: str,
secret: str,
session: int,
server: str,
proxy_host: str,
proxy_port: int
self,
key: str,
secret: str,
session: int,
server: str,
proxy_host: str,
proxy_port: int,
):
"""
Initialize connection to REST server.
@ -193,9 +183,9 @@ class BitmexRestApi(RestClient):
self.key = key
self.secret = secret.encode()
self.connect_time = int(
datetime.now().strftime("%y%m%d%H%M%S")
) * self.order_count
self.connect_time = (
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
)
if server == "REAL":
self.init(REST_HOST, proxy_host, proxy_port)
@ -204,7 +194,7 @@ class BitmexRestApi(RestClient):
self.start(session)
self.gateway.write_log(u"REST API启动成功")
self.gateway.write_log("REST API启动成功")
def send_order(self, req: SubscribeRequest):
""""""
@ -217,7 +207,7 @@ class BitmexRestApi(RestClient):
"ordType": PRICETYPE_VT2BITMEX[req.price_type],
"price": req.price,
"orderQty": int(req.volume),
"clOrdID": orderid
"clOrdID": orderid,
}
# Only add price for limit order.
@ -268,11 +258,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg)
def on_send_order_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when sending order caused exception.
@ -290,11 +276,7 @@ class BitmexRestApi(RestClient):
pass
def on_cancel_order_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when cancelling order failed on server.
@ -315,11 +297,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg)
def on_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback to handler request exception.
@ -328,10 +306,7 @@ class BitmexRestApi(RestClient):
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type,
exception_value,
tb,
request)
self.exception_detail(exception_type, exception_value, tb, request)
)
@ -355,7 +330,7 @@ class BitmexWebsocketApi(WebsocketClient):
"order": self.on_order,
"position": self.on_position,
"margin": self.on_account,
"instrument": self.on_contract
"instrument": self.on_contract,
}
self.ticks = {}
@ -364,12 +339,7 @@ class BitmexWebsocketApi(WebsocketClient):
self.trades = set()
def connect(
self,
key: str,
secret: str,
server: str,
proxy_host: str,
proxy_port: int
self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int
):
""""""
self.key = key
@ -391,23 +361,23 @@ class BitmexWebsocketApi(WebsocketClient):
exchange=req.exchange,
name=req.symbol,
datetime=datetime.now(),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.ticks[req.symbol] = tick
def on_connected(self):
""""""
self.gateway.write_log(u"Websocket API连接成功")
self.gateway.write_log("Websocket API连接成功")
self.authenticate()
def on_disconnected(self):
""""""
self.gateway.write_log(u"Websocket API连接断开")
self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict):
""""""
if "error" in packet:
self.gateway.write_log(u"Websocket API报错%s" % packet["error"])
self.gateway.write_log("Websocket API报错%s" % packet["error"])
if "not valid" in packet["error"]:
self.active = False
@ -418,7 +388,7 @@ class BitmexWebsocketApi(WebsocketClient):
if success:
if req["op"] == "authKey":
self.gateway.write_log(u"Websocket API验证授权成功")
self.gateway.write_log("Websocket API验证授权成功")
self.subscribe_topic()
elif "table" in packet:
@ -436,11 +406,7 @@ class BitmexWebsocketApi(WebsocketClient):
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type,
exception_value,
tb)
)
sys.stderr.write(self.exception_detail(exception_type, exception_value, tb))
def authenticate(self):
"""
@ -451,9 +417,7 @@ class BitmexWebsocketApi(WebsocketClient):
path = "/realtime"
msg = method + path + str(expires)
signature = hmac.new(
self.secret,
msg.encode(),
digestmod=hashlib.sha256
self.secret, msg.encode(), digestmod=hashlib.sha256
).hexdigest()
req = {"op": "authKey", "args": [self.key, expires, signature]}
@ -464,8 +428,7 @@ class BitmexWebsocketApi(WebsocketClient):
Subscribe to all private topics.
"""
req = {
"op":
"subscribe",
"op": "subscribe",
"args": [
"instrument",
"trade",
@ -473,8 +436,8 @@ class BitmexWebsocketApi(WebsocketClient):
"execution",
"order",
"position",
"margin"
]
"margin",
],
}
self.send_packet(req)
@ -486,10 +449,7 @@ class BitmexWebsocketApi(WebsocketClient):
return
tick.last_price = d["price"]
tick.datetime = datetime.strptime(
d["timestamp"],
"%Y-%m-%dT%H:%M:%S.%fZ"
)
tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_depth(self, d):
@ -509,10 +469,7 @@ class BitmexWebsocketApi(WebsocketClient):
tick.__setattr__("ask_price_%s" % (n + 1), price)
tick.__setattr__("ask_volume_%s" % (n + 1), volume)
tick.datetime = datetime.strptime(
d["timestamp"],
"%Y-%m-%dT%H:%M:%S.%fZ"
)
tick.datetime = datetime.strptime(d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_trade(self, d):
@ -540,7 +497,7 @@ class BitmexWebsocketApi(WebsocketClient):
price=d["lastPx"],
volume=d["lastQty"],
time=d["timestamp"][11:19],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
@ -568,7 +525,7 @@ class BitmexWebsocketApi(WebsocketClient):
price=d["price"],
volume=d["orderQty"],
time=d["timestamp"][11:19],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.orders[sysid] = order
@ -584,7 +541,7 @@ class BitmexWebsocketApi(WebsocketClient):
exchange=Exchange.BITMEX,
direction=Direction.NET,
volume=d["currentQty"],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.gateway.on_position(position)
@ -594,10 +551,7 @@ class BitmexWebsocketApi(WebsocketClient):
accountid = str(d["account"])
account = self.accounts.get(accountid, None)
if not account:
account = AccountData(
accountid=accountid,
gateway_name=self.gateway_name
)
account = AccountData(accountid=accountid, gateway_name=self.gateway_name)
self.accounts[accountid] = account
account.balance = d.get("marginBalance", account.balance)
@ -621,7 +575,7 @@ class BitmexWebsocketApi(WebsocketClient):
product=Product.FUTURES,
pricetick=d["tickSize"],
size=d["lotSize"],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.gateway.on_contract(contract)

View File

@ -22,7 +22,7 @@ from futu import (
StockQuoteHandlerBase,
OrderBookHandlerBase,
TradeOrderHandlerBase,
TradeDealHandlerBase
TradeDealHandlerBase,
)
from vnpy.trader.gateway import BaseGateway
@ -36,14 +36,14 @@ from vnpy.trader.object import (
AccountData,
SubscribeRequest,
OrderRequest,
CancelRequest
CancelRequest,
)
from vnpy.trader.event import EVENT_TIMER
EXCHANGE_VT2FUTU = {
Exchange.SMART: "US",
Exchange.SEHK: "HK",
Exchange.HKFE: "HK_FUTURE"
Exchange.HKFE: "HK_FUTURE",
}
EXCHANGE_FUTU2VT = {v: k for k, v in EXCHANGE_VT2FUTU.items()}
@ -52,7 +52,7 @@ PRODUCT_VT2FUTU = {
Product.INDEX: "IDX",
Product.ETF: "ETF",
Product.WARRANT: "WARRANT",
Product.BOND: "BOND"
Product.BOND: "BOND",
}
DIRECTION_VT2FUTU = {Direction.LONG: TrdSide.BUY, Direction.SHORT: TrdSide.SELL}
@ -79,10 +79,8 @@ class FutuGateway(BaseGateway):
"password": "",
"host": "127.0.0.1",
"port": 11111,
"market": ["HK",
"US"],
"env": [TrdEnv.REAL,
TrdEnv.SIMULATE]
"market": ["HK", "US"],
"env": [TrdEnv.REAL, TrdEnv.SIMULATE],
}
def __init__(self, event_engine):
@ -126,7 +124,7 @@ class FutuGateway(BaseGateway):
"""
Query all data necessary.
"""
sleep(2.0) # Wait 2 seconds till connection completed.
sleep(2.0) # Wait 2 seconds till connection completed.
self.query_contract()
self.query_trade()
@ -235,7 +233,7 @@ class FutuGateway(BaseGateway):
def send_order(self, req):
""""""
side = DIRECTION_VT2FUTU[req.direction]
price_type = OrderType.NORMAL # Only limit order is supported.
price_type = OrderType.NORMAL # Only limit order is supported.
# Set price adjustment mode to inside adjustment.
if req.direction is Direction.LONG:
@ -251,7 +249,7 @@ class FutuGateway(BaseGateway):
side,
price_type,
trd_env=self.env,
adjust_limit=adjust_limit
adjust_limit=adjust_limit,
)
if code:
@ -268,11 +266,8 @@ class FutuGateway(BaseGateway):
def cancel_order(self, req):
""""""
code, data = self.trade_ctx.modify_order(
ModifyOrderOp.CANCEL,
req.orderid,
0,
0,
trd_env=self.env)
ModifyOrderOp.CANCEL, req.orderid, 0, 0, trd_env=self.env
)
if code:
self.write_log(f"撤单失败:{data}")
@ -295,7 +290,7 @@ class FutuGateway(BaseGateway):
product=product,
size=1,
pricetick=0.001,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.on_contract(contract)
self.contracts[contract.vt_symbol] = contract
@ -314,11 +309,8 @@ class FutuGateway(BaseGateway):
account = AccountData(
accountid=f"{self.gateway_name}_{self.market}",
balance=float(row["total_assets"]),
frozen=(
float(row["total_assets"]) -
float(row["avl_withdrawal_cash"])
),
gateway_name=self.gateway_name
frozen=(float(row["total_assets"]) - float(row["avl_withdrawal_cash"])),
gateway_name=self.gateway_name,
)
self.on_account(account)
@ -340,7 +332,7 @@ class FutuGateway(BaseGateway):
frozen=(float(row["qty"]) - float(row["can_sell_qty"])),
price=float(row["pl_val"]),
pnl=float(row["cost_price"]),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.on_position(pos)
@ -386,7 +378,7 @@ class FutuGateway(BaseGateway):
symbol=symbol,
exchange=exchange,
datetime=datetime.now(),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.ticks[code] = tick
@ -405,10 +397,7 @@ class FutuGateway(BaseGateway):
date = row["data_date"].replace("-", "")
time = row["data_time"]
tick.datetime = datetime.strptime(
f"{date} {time}",
"%Y%m%d %H:%M:%S"
)
tick.datetime = datetime.strptime(f"{date} {time}", "%Y%m%d %H:%M:%S")
tick.open_price = row["open_price"]
tick.high_price = row["high_price"]
tick.low_price = row["low_price"]
@ -462,7 +451,7 @@ class FutuGateway(BaseGateway):
traded=float(row["dealt_qty"]),
status=STATUS_FUTU2VT[row["order_status"]],
time=row["create_time"].split(" ")[-1],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.on_order(order)
@ -487,7 +476,7 @@ class FutuGateway(BaseGateway):
price=float(row["price"]),
volume=float(row["qty"]),
time=row["create_time"].split(" ")[-1],
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.on_trade(trade)

View File

@ -27,7 +27,7 @@ from vnpy.trader.object import (
AccountData,
SubscribeRequest,
OrderRequest,
CancelRequest
CancelRequest,
)
from vnpy.trader.constant import (
Product,
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
Exchange,
Currency,
Status,
OptionType
OptionType,
)
PRICETYPE_VT2IB = {PriceType.LIMIT: "LMT", PriceType.MARKET: "MKT"}
@ -55,7 +55,7 @@ EXCHANGE_VT2IB = {
Exchange.CME: "CME",
Exchange.ICE: "ICE",
Exchange.SEHK: "SEHK",
Exchange.HKFE: "HKFE"
Exchange.HKFE: "HKFE",
}
EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()}
@ -64,7 +64,7 @@ STATUS_IB2VT = {
"Filled": Status.ALLTRADED,
"Cancelled": Status.CANCELLED,
"PendingSubmit": Status.SUBMITTING,
"PreSubmitted": Status.NOTTRADED
"PreSubmitted": Status.NOTTRADED,
}
PRODUCT_VT2IB = {
@ -72,7 +72,7 @@ PRODUCT_VT2IB = {
Product.FOREX: "CASH",
Product.SPOT: "CMDTY",
Product.OPTION: "OPT",
Product.FUTURES: "FUT"
Product.FUTURES: "FUT",
}
PRODUCT_IB2VT = {v: k for k, v in PRODUCT_VT2IB.items()}
@ -91,7 +91,7 @@ TICKFIELD_IB2VT = {
7: "low_price",
8: "volume",
9: "pre_close",
14: "open_price"
14: "open_price",
}
ACCOUNTFIELD_IB2VT = {
@ -182,21 +182,21 @@ class IbApi(EWrapper):
self.client = IbClient(self)
self.thread = Thread(target=self.client.run)
def connectAck(self): # pylint: disable=invalid-name
def connectAck(self): # pylint: disable=invalid-name
"""
Callback when connection is established.
"""
self.status = True
self.gateway.write_log("IB TWS连接成功")
def connectionClosed(self): # pylint: disable=invalid-name
def connectionClosed(self): # pylint: disable=invalid-name
"""
Callback when connection is closed.
"""
self.status = False
self.gateway.write_log("IB TWS连接断开")
def nextValidId(self, orderId: int): # pylint: disable=invalid-name
def nextValidId(self, orderId: int): # pylint: disable=invalid-name
"""
Callback of next valid orderid.
"""
@ -204,7 +204,7 @@ class IbApi(EWrapper):
self.orderid = orderId
def currentTime(self, time: int): # pylint: disable=invalid-name
def currentTime(self, time: int): # pylint: disable=invalid-name
"""
Callback of current server time of IB.
"""
@ -216,7 +216,9 @@ class IbApi(EWrapper):
msg = f"服务器时间: {time_string}"
self.gateway.write_log(msg)
def error(self, reqId: TickerId, errorCode: int, errorString: str): # pylint: disable=invalid-name
def error(
self, reqId: TickerId, errorCode: int, errorString: str
): # pylint: disable=invalid-name
"""
Callback of error caused by specific request.
"""
@ -226,11 +228,7 @@ class IbApi(EWrapper):
self.gateway.write_log(msg)
def tickPrice( # pylint: disable=invalid-name
self,
reqId: TickerId,
tickType: TickType,
price: float,
attrib: TickAttrib
self, reqId: TickerId, tickType: TickType, price: float, attrib: TickAttrib
):
"""
Callback of tick price update.
@ -257,7 +255,9 @@ class IbApi(EWrapper):
tick.datetime = datetime.now()
self.gateway.on_tick(copy(tick))
def tickSize(self, reqId: TickerId, tickType: TickType, size: int): # pylint: disable=invalid-name
def tickSize(
self, reqId: TickerId, tickType: TickType, size: int
): # pylint: disable=invalid-name
"""
Callback of tick volume update.
"""
@ -272,7 +272,9 @@ class IbApi(EWrapper):
self.gateway.on_tick(copy(tick))
def tickString(self, reqId: TickerId, tickType: TickType, value: str): # pylint: disable=invalid-name
def tickString(
self, reqId: TickerId, tickType: TickType, value: str
): # pylint: disable=invalid-name
"""
Callback of tick string update.
"""
@ -287,36 +289,35 @@ class IbApi(EWrapper):
self.gateway.on_tick(copy(tick))
def orderStatus( # pylint: disable=invalid-name
self,
orderId: OrderId,
status: str,
filled: float,
remaining: float,
avgFillPrice: float,
permId: int,
parentId: int,
lastFillPrice: float,
clientId: int,
whyHeld: str,
mktCapPrice: float
self,
orderId: OrderId,
status: str,
filled: float,
remaining: float,
avgFillPrice: float,
permId: int,
parentId: int,
lastFillPrice: float,
clientId: int,
whyHeld: str,
mktCapPrice: float,
):
"""
Callback of order status update.
"""
super(IbApi,
self).orderStatus(
orderId,
status,
filled,
remaining,
avgFillPrice,
permId,
parentId,
lastFillPrice,
clientId,
whyHeld,
mktCapPrice
)
super(IbApi, self).orderStatus(
orderId,
status,
filled,
remaining,
avgFillPrice,
permId,
parentId,
lastFillPrice,
clientId,
whyHeld,
mktCapPrice,
)
orderid = str(orderId)
order = self.orders.get(orderid, None)
@ -326,11 +327,11 @@ class IbApi(EWrapper):
self.gateway.on_order(copy(order))
def openOrder( # pylint: disable=invalid-name
self,
orderId: OrderId,
ib_contract: Contract,
ib_order: Order,
orderState: OrderState
self,
orderId: OrderId,
ib_contract: Contract,
ib_order: Order,
orderState: OrderState,
):
"""
Callback when opening new order.
@ -340,26 +341,19 @@ class IbApi(EWrapper):
orderid = str(orderId)
order = OrderData(
symbol=ib_contract.conId,
exchange=EXCHANGE_IB2VT.get(
ib_contract.exchange,
ib_contract.exchange
),
exchange=EXCHANGE_IB2VT.get(ib_contract.exchange, ib_contract.exchange),
orderid=orderid,
direction=DIRECTION_IB2VT[ib_order.action],
price=ib_order.lmtPrice,
volume=ib_order.totalQuantity,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.orders[orderid] = order
self.gateway.on_order(copy(order))
def updateAccountValue( # pylint: disable=invalid-name
self,
key: str,
val: str,
currency: str,
accountName: str
self, key: str, val: str, currency: str, accountName: str
):
"""
Callback of account update.
@ -372,54 +366,49 @@ class IbApi(EWrapper):
accountid = f"{accountName}.{currency}"
account = self.accounts.get(accountid, None)
if not account:
account = AccountData(
accountid=accountid,
gateway_name=self.gateway_name
)
account = AccountData(accountid=accountid, gateway_name=self.gateway_name)
self.accounts[accountid] = account
name = ACCOUNTFIELD_IB2VT[key]
setattr(account, name, float(val))
def updatePortfolio( # pylint: disable=invalid-name
self,
contract: Contract,
position: float,
marketPrice: float,
marketValue: float,
averageCost: float,
unrealizedPNL: float,
realizedPNL: float,
accountName: str
self,
contract: Contract,
position: float,
marketPrice: float,
marketValue: float,
averageCost: float,
unrealizedPNL: float,
realizedPNL: float,
accountName: str,
):
"""
Callback of position update.
"""
super(IbApi,
self).updatePortfolio(
contract,
position,
marketPrice,
marketValue,
averageCost,
unrealizedPNL,
realizedPNL,
accountName
)
super(IbApi, self).updatePortfolio(
contract,
position,
marketPrice,
marketValue,
averageCost,
unrealizedPNL,
realizedPNL,
accountName,
)
pos = PositionData(
symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange,
contract.exchange),
exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
direction=DIRECTION_NET,
volume=position,
price=averageCost,
pnl=unrealizedPNL,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.gateway.on_position(pos)
def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name
def updateAccountTime(self, timeStamp: str): # pylint: disable=invalid-name
"""
Callback of account update time.
"""
@ -427,7 +416,9 @@ class IbApi(EWrapper):
for account in self.accounts.values():
self.gateway.on_account(copy(account))
def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name
def contractDetails(
self, reqId: int, contractDetails: ContractDetails
): # pylint: disable=invalid-name
"""
Callback of contract data update.
"""
@ -443,20 +434,21 @@ class IbApi(EWrapper):
contract = ContractData(
symbol=ib_symbol,
exchange=EXCHANGE_IB2VT.get(ib_exchange,
ib_exchange),
exchange=EXCHANGE_IB2VT.get(ib_exchange, ib_exchange),
name=contractDetails.longName,
product=PRODUCT_IB2VT[ib_product],
size=ib_size,
pricetick=contractDetails.minTick,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.gateway.on_contract(contract)
self.contracts[contract.vt_symbol] = contract
def execDetails(self, reqId: int, contract: Contract, execution: Execution): # pylint: disable=invalid-name
def execDetails(
self, reqId: int, contract: Contract, execution: Execution
): # pylint: disable=invalid-name
"""
Callback of trade data update.
"""
@ -465,21 +457,19 @@ class IbApi(EWrapper):
today_date = datetime.now().strftime("%Y%m%d")
trade = TradeData(
symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange,
contract.exchange),
exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
orderid=str(execution.orderId),
tradeid=str(execution.execId),
direction=DIRECTION_IB2VT[execution.side],
price=execution.price,
volume=execution.shares,
time=datetime.strptime(execution.time,
"%Y%m%d %H:%M:%S"),
gateway_name=self.gateway_name
time=datetime.strptime(execution.time, "%Y%m%d %H:%M:%S"),
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name
def managedAccounts(self, accountsList: str): # pylint: disable=invalid-name
"""
Callback of all sub accountid.
"""
@ -497,11 +487,7 @@ class IbApi(EWrapper):
self.clientid = setting["clientid"]
self.client.connect(
setting["host"],
setting["port"],
setting["clientid"]
)
self.client.connect(setting["host"], setting["port"], setting["clientid"])
self.thread.start()
@ -544,7 +530,7 @@ class IbApi(EWrapper):
symbol=req.symbol,
exchange=req.exchange,
datetime=datetime.now(),
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
self.ticks[self.reqid] = tick
self.tick_exchange[self.reqid] = req.exchange

View File

@ -7,10 +7,11 @@ class BaseApp(ABC):
"""
Absstract class for app.
"""
app_name = "" # Unique name used for creating engine and widget
app_module = "" # App module string used in import_module
app_path = "" # Absolute path of app folder
display_name = "" # Name for display on the menu.
engine_class = None # App engine class
widget_name = "" # Class name of app widget
icon_name = "" # Icon file name of app widget
app_name = "" # Unique name used for creating engine and widget
app_module = "" # App module string used in import_module
app_path = "" # Absolute path of app folder
display_name = "" # Name for display on the menu.
engine_class = None # App engine class
widget_name = "" # Class name of app widget
icon_name = "" # Icon file name of app widget

View File

@ -9,6 +9,7 @@ class Direction(Enum):
"""
Direction of order/trade/position.
"""
LONG = ""
SHORT = ""
NET = ""
@ -18,6 +19,7 @@ class Offset(Enum):
"""
Offset of order/trade.
"""
NONE = ""
OPEN = ""
CLOSE = ""
@ -29,6 +31,7 @@ class Status(Enum):
"""
Order status.
"""
SUBMITTING = "提交中"
NOTTRADED = "未成交"
PARTTRADED = "部分成交"
@ -41,6 +44,7 @@ class Product(Enum):
"""
Product class.
"""
EQUITY = "股票"
FUTURES = "期货"
OPTION = "期权"
@ -56,6 +60,7 @@ class PriceType(Enum):
"""
Order price type.
"""
LIMIT = "限价"
MARKET = "市价"
FAK = "FAK"
@ -66,6 +71,7 @@ class OptionType(Enum):
"""
Option type.
"""
CALL = "看涨期权"
PUT = "看跌期权"
@ -74,6 +80,7 @@ class Exchange(Enum):
"""
Exchange.
"""
# Chinese
CFFEX = "CFFEX"
SHFE = "SHFE"
@ -102,6 +109,7 @@ class Currency(Enum):
"""
Currency.
"""
USD = "USD"
HKD = "HKD"
CNY = "CNY"

View File

@ -6,7 +6,7 @@ from peewee import (
CharField,
DateTimeField,
FloatField,
IntegerField
IntegerField,
)
from .utility import get_temp_path
@ -23,6 +23,7 @@ class DbBarData(Model):
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
@ -39,7 +40,7 @@ class DbBarData(Model):
class Meta:
database = DB
indexes = ((('vt_symbol', 'interval', 'datetime'), True),)
indexes = ((("vt_symbol", "interval", "datetime"), True),)
@staticmethod
def from_bar(bar: BarData):
@ -76,7 +77,7 @@ class DbBarData(Model):
high_price=high_price,
low_price=low_price,
close_price=close_price,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
return bar
@ -87,6 +88,7 @@ class DbTickData(Model):
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
@ -132,7 +134,7 @@ class DbTickData(Model):
class Meta:
database = DB
indexes = ((('vt_symbol', 'datetime'), True),)
indexes = ((("vt_symbol", "datetime"), True),)
@staticmethod
def from_tick(tick: TickData):
@ -208,7 +210,7 @@ class DbTickData(Model):
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name=self.gateway_name
gateway_name=self.gateway_name,
)
if self.bid_price_2:

View File

@ -19,7 +19,7 @@ from .event import (
EVENT_TRADE,
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_CONTRACT
EVENT_CONTRACT,
)
from .object import LogData, SubscribeRequest, OrderRequest, CancelRequest
from .utility import Singleton, get_temp_path
@ -180,10 +180,7 @@ class BaseEngine(ABC):
"""
def __init__(
self,
main_engine: MainEngine,
event_engine: EventEngine,
engine_name: str
self, main_engine: MainEngine, event_engine: EventEngine, engine_name: str
):
""""""
self.main_engine = main_engine
@ -211,9 +208,7 @@ class LogEngine(BaseEngine):
self.level = SETTINGS["log.level"]
self.logger = logging.getLogger("VN Trader")
self.formatter = logging.Formatter(
"%(asctime)s %(levelname)s: %(message)s"
)
self.formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
self.add_null_handler()
@ -249,7 +244,7 @@ class LogEngine(BaseEngine):
filename = f"vt_{today_date}.log"
file_path = get_temp_path(filename)
file_handler = logging.FileHandler(file_path, mode='w', encoding='utf8')
file_handler = logging.FileHandler(file_path, mode="w", encoding="utf8")
file_handler.setLevel(self.level)
file_handler.setFormatter(self.formatter)
self.logger.addHandler(file_handler)
@ -421,7 +416,7 @@ class OmsEngine(BaseEngine):
"""
return list(self.contracts.values())
def get_all_active_orders(self, vt_symbol: str = ''):
def get_all_active_orders(self, vt_symbol: str = ""):
"""
Get all active orders by vt_symbol.
@ -431,7 +426,8 @@ class OmsEngine(BaseEngine):
return list(self.active_orders.values())
else:
active_orders = [
order for order in self.active_orders.values()
order
for order in self.active_orders.values()
if order.vt_symbol == vt_symbol
]
return active_orders
@ -476,12 +472,10 @@ class EmailEngine(BaseEngine):
try:
msg = self.queue.get(block=True, timeout=1)
with smtplib.SMTP_SSL(SETTINGS["email.server"],
SETTINGS["email.port"]) as smtp:
smtp.login(
SETTINGS["email.username"],
SETTINGS["email.password"]
)
with smtplib.SMTP_SSL(
SETTINGS["email.server"], SETTINGS["email.port"]
) as smtp:
smtp.login(SETTINGS["email.username"], SETTINGS["email.password"])
smtp.send_message(msg)
except Empty:
pass

View File

@ -4,10 +4,10 @@ Event type string used in VN Trader.
from vnpy.event import EVENT_TIMER
EVENT_TICK = 'eTick.'
EVENT_TRADE = 'eTrade.'
EVENT_ORDER = 'eOrder.'
EVENT_POSITION = 'ePosition.'
EVENT_ACCOUNT = 'eAccount.'
EVENT_CONTRACT = 'eContract.'
EVENT_LOG = 'eLog'
EVENT_TICK = "eTick."
EVENT_TRADE = "eTrade."
EVENT_ORDER = "eOrder."
EVENT_POSITION = "ePosition."
EVENT_ACCOUNT = "eAccount."
EVENT_CONTRACT = "eContract."
EVENT_LOG = "eLog"

View File

@ -14,7 +14,7 @@ from .event import (
EVENT_ACCOUNT,
EVENT_POSITION,
EVENT_LOG,
EVENT_CONTRACT
EVENT_CONTRACT,
)
from .object import (
TickData,
@ -26,7 +26,7 @@ from .object import (
ContractData,
SubscribeRequest,
OrderRequest,
CancelRequest
CancelRequest,
)

View File

@ -17,6 +17,7 @@ class BaseData:
Any data object needs a gateway_name as source or
destination and should inherit base data.
"""
gateway_name: str
@ -28,6 +29,7 @@ class TickData(BaseData):
* orderbook snapshot
* intraday market statistics.
"""
symbol: str
exchange: Exchange
datetime: datetime
@ -78,6 +80,7 @@ class BarData(BaseData):
"""
Candlestick bar data of a certain trading period.
"""
symbol: str
exchange: Exchange
datetime: datetime
@ -100,6 +103,7 @@ class OrderData(BaseData):
Order data contains information for tracking lastest status
of a specific order.
"""
symbol: str
exchange: Exchange
orderid: str
@ -131,9 +135,7 @@ class OrderData(BaseData):
Create cancel request object from order.
"""
req = CancelRequest(
orderid=self.orderid,
symbol=self.symbol,
exchange=self.exchange
orderid=self.orderid, symbol=self.symbol, exchange=self.exchange
)
return req
@ -144,6 +146,7 @@ class TradeData(BaseData):
Trade data contains information of a fill of an order. One order
can have several trade fills.
"""
symbol: str
exchange: Exchange
orderid: str
@ -167,6 +170,7 @@ class PositionData(BaseData):
"""
Positon data is used for tracking each individual position holding.
"""
symbol: str
exchange: Exchange
direction: Direction
@ -188,6 +192,7 @@ class AccountData(BaseData):
Account data contains information about balance, frozen and
available.
"""
accountid: str
balance: float = 0
@ -204,6 +209,7 @@ class LogData(BaseData):
"""
Log data is used for recording log messages on GUI or in log files.
"""
msg: str
level: int = INFO
@ -217,6 +223,7 @@ class ContractData(BaseData):
"""
Contract data contains basic information about each contract traded.
"""
symbol: str
exchange: Exchange
name: str
@ -225,8 +232,8 @@ class ContractData(BaseData):
pricetick: float
option_strike: float = 0
option_underlying: str = '' # vt_symbol of underlying contract
option_type: str = ''
option_underlying: str = "" # vt_symbol of underlying contract
option_type: str = ""
option_expiry: datetime = None
def __post_init__(self):
@ -239,6 +246,7 @@ class SubscribeRequest:
"""
Request sending to specific gateway for subscribing tick data update.
"""
symbol: str
exchange: Exchange
@ -252,6 +260,7 @@ class OrderRequest:
"""
Request sending to specific gateway for creating a new order.
"""
symbol: str
exchange: Exchange
direction: Direction
@ -276,7 +285,7 @@ class OrderRequest:
offset=self.offset,
price=self.price,
volume=self.volume,
gateway_name=gateway_name
gateway_name=gateway_name,
)
return order
@ -286,6 +295,7 @@ class CancelRequest:
"""
Request sending to specific gateway for canceling an existing order.
"""
orderid: str
symbol: str
exchange: Exchange

View File

@ -16,5 +16,5 @@ SETTINGS = {
"email.username": "",
"email.password": "",
"email.sender": "",
"email.receiver": ""
"email.receiver": "",
}

View File

@ -14,13 +14,8 @@ from ..utility import get_icon_path
def excepthook(exctype, value, tb):
"""异常捕捉钩子"""
msg = ''.join(traceback.format_exception(exctype, value, tb))
QtWidgets.QMessageBox.critical(
None,
u'Exception',
msg,
QtWidgets.QMessageBox.Ok
)
msg = "".join(traceback.format_exception(exctype, value, tb))
QtWidgets.QMessageBox.critical(None, u"Exception", msg, QtWidgets.QMessageBox.Ok)
def create_qapp():
@ -38,9 +33,7 @@ def create_qapp():
icon = QtGui.QIcon(get_icon_path(__file__, "vnpy.ico"))
qapp.setWindowIcon(icon)
if 'Windows' in platform.uname():
ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(
'VN Trader'
)
if "Windows" in platform.uname():
ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("VN Trader")
return qapp

View File

@ -22,7 +22,7 @@ from .widget import (
TradingWidget,
ActiveOrderMonitor,
ContractManager,
AboutDialog
AboutDialog,
)
@ -54,14 +54,30 @@ class MainWindow(QtWidgets.QMainWindow):
def init_dock(self):
""""""
trading_widget, trading_dock = self.create_dock(TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea)
tick_widget, tick_dock = self.create_dock(TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea)
order_widget, order_dock = self.create_dock(OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea)
active_widget, active_dock = self.create_dock(ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea)
trade_widget, trade_dock = self.create_dock(TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea)
log_widget, log_dock = self.create_dock(LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea)
account_widget, account_dock = self.create_dock(AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea)
position_widget, position_dock = self.create_dock(PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea)
trading_widget, trading_dock = self.create_dock(
TradingWidget, "交易", QtCore.Qt.LeftDockWidgetArea
)
tick_widget, tick_dock = self.create_dock(
TickMonitor, "行情", QtCore.Qt.RightDockWidgetArea
)
order_widget, order_dock = self.create_dock(
OrderMonitor, "委托", QtCore.Qt.RightDockWidgetArea
)
active_widget, active_dock = self.create_dock(
ActiveOrderMonitor, "活动", QtCore.Qt.RightDockWidgetArea
)
trade_widget, trade_dock = self.create_dock(
TradeMonitor, "成交", QtCore.Qt.RightDockWidgetArea
)
log_widget, log_dock = self.create_dock(
LogMonitor, "日志", QtCore.Qt.BottomDockWidgetArea
)
account_widget, account_dock = self.create_dock(
AccountMonitor, "资金", QtCore.Qt.BottomDockWidgetArea
)
position_widget, position_dock = self.create_dock(
PositionMonitor, "持仓", QtCore.Qt.BottomDockWidgetArea
)
self.tabifyDockWidget(active_dock, order_dock)
@ -96,52 +112,31 @@ class MainWindow(QtWidgets.QMainWindow):
func = partial(self.open_widget, widget_class, app.app_name)
icon_path = str(app.app_path.joinpath("ui", app.icon_name))
self.add_menu_action(
app_menu,
f"打开{app.display_name}",
icon_path,
func
)
self.add_menu_action(app_menu, f"打开{app.display_name}", icon_path, func)
# Help menu
self.add_menu_action(
help_menu,
"查询合约",
"contract.ico",
partial(self.open_widget,
ContractManager,
"contract")
partial(self.open_widget, ContractManager, "contract"),
)
self.add_menu_action(
help_menu,
"还原窗口",
"restore.ico",
self.restore_window_setting
help_menu, "还原窗口", "restore.ico", self.restore_window_setting
)
self.add_menu_action(
help_menu,
"测试邮件",
"email.ico",
self.send_test_email
)
self.add_menu_action(help_menu, "测试邮件", "email.ico", self.send_test_email)
self.add_menu_action(
help_menu,
"关于",
"about.ico",
partial(self.open_widget,
AboutDialog,
"about")
partial(self.open_widget, AboutDialog, "about"),
)
def add_menu_action(
self,
menu: QtWidgets.QMenu,
action_name: str,
icon_name: str,
func: Callable
self, menu: QtWidgets.QMenu, action_name: str, icon_name: str, func: Callable
):
""""""
icon = QtGui.QIcon(get_icon_path(__file__, icon_name))
@ -152,12 +147,7 @@ class MainWindow(QtWidgets.QMainWindow):
menu.addAction(action)
def create_dock(
self,
widget_class: QtWidgets.QWidget,
name: str,
area: int
):
def create_dock(self, widget_class: QtWidgets.QWidget, name: str, area: int):
"""
Initialize a dock widget.
"""
@ -189,7 +179,7 @@ class MainWindow(QtWidgets.QMainWindow):
"退出",
"确认退出?",
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
QtWidgets.QMessageBox.No
QtWidgets.QMessageBox.No,
)
if reply == QtWidgets.QMessageBox.Yes:

View File

@ -18,7 +18,7 @@ from ..event import (
EVENT_ACCOUNT,
EVENT_POSITION,
EVENT_CONTRACT,
EVENT_LOG
EVENT_LOG,
)
from ..object import SubscribeRequest, OrderRequest, CancelRequest
from ..utility import load_setting, save_setting
@ -299,22 +299,19 @@ class BaseMonitor(QtWidgets.QTableWidget):
"""
Resize all columns according to contents.
"""
self.horizontalHeader().resizeSections(
QtWidgets.QHeaderView.ResizeToContents
)
self.horizontalHeader().resizeSections(QtWidgets.QHeaderView.ResizeToContents)
def save_csv(self):
"""
Save table data into a csv file
"""
path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "",
"CSV(*.csv)")
path, _ = QtWidgets.QFileDialog.getSaveFileName(self, "保存数据", "", "CSV(*.csv)")
if not path:
return
with open(path, "w") as f:
writer = csv.writer(f, lineterminator='\n')
writer = csv.writer(f, lineterminator="\n")
writer.writerow(self.headers.keys())
@ -339,81 +336,26 @@ class TickMonitor(BaseMonitor):
"""
Monitor for tick data.
"""
event_type = EVENT_TICK
data_key = "vt_symbol"
sorting = True
headers = {
"symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"exchange": {
"display": "交易所",
"cell": EnumCell,
"update": False
},
"name": {
"display": "名称",
"cell": BaseCell,
"update": True
},
"last_price": {
"display": "最新价",
"cell": BaseCell,
"update": True
},
"volume": {
"display": "成交量",
"cell": BaseCell,
"update": True
},
"open_price": {
"display": "开盘价",
"cell": BaseCell,
"update": True
},
"high_price": {
"display": "最高价",
"cell": BaseCell,
"update": True
},
"low_price": {
"display": "最低价",
"cell": BaseCell,
"update": True
},
"bid_price_1": {
"display": "买1价",
"cell": BidCell,
"update": True
},
"bid_volume_1": {
"display": "买1量",
"cell": BidCell,
"update": True
},
"ask_price_1": {
"display": "卖1价",
"cell": AskCell,
"update": True
},
"ask_volume_1": {
"display": "卖1量",
"cell": AskCell,
"update": True
},
"datetime": {
"display": "时间",
"cell": TimeCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"symbol": {"display": "代码", "cell": BaseCell, "update": False},
"exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"name": {"display": "名称", "cell": BaseCell, "update": True},
"last_price": {"display": "最新价", "cell": BaseCell, "update": True},
"volume": {"display": "成交量", "cell": BaseCell, "update": True},
"open_price": {"display": "开盘价", "cell": BaseCell, "update": True},
"high_price": {"display": "最高价", "cell": BaseCell, "update": True},
"low_price": {"display": "最低价", "cell": BaseCell, "update": True},
"bid_price_1": {"display": "买1价", "cell": BidCell, "update": True},
"bid_volume_1": {"display": "买1量", "cell": BidCell, "update": True},
"ask_price_1": {"display": "卖1价", "cell": AskCell, "update": True},
"ask_volume_1": {"display": "卖1量", "cell": AskCell, "update": True},
"datetime": {"display": "时间", "cell": TimeCell, "update": True},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
@ -421,26 +363,15 @@ class LogMonitor(BaseMonitor):
"""
Monitor for log data.
"""
event_type = EVENT_LOG
data_key = ""
sorting = False
headers = {
"time": {
"display": "时间",
"cell": TimeCell,
"update": False
},
"msg": {
"display": "信息",
"cell": MsgCell,
"update": False
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"time": {"display": "时间", "cell": TimeCell, "update": False},
"msg": {"display": "信息", "cell": MsgCell, "update": False},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
@ -448,61 +379,22 @@ class TradeMonitor(BaseMonitor):
"""
Monitor for trade data.
"""
event_type = EVENT_TRADE
data_key = ""
sorting = True
headers = {
"tradeid": {
"display": "成交号 ",
"cell": BaseCell,
"update": False
},
"orderid": {
"display": "委托号",
"cell": BaseCell,
"update": False
},
"symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"exchange": {
"display": "交易所",
"cell": EnumCell,
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"offset": {
"display": "开平",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": False
},
"time": {
"display": "时间",
"cell": BaseCell,
"update": False
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"tradeid": {"display": "成交号 ", "cell": BaseCell, "update": False},
"orderid": {"display": "委托号", "cell": BaseCell, "update": False},
"symbol": {"display": "代码", "cell": BaseCell, "update": False},
"exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"direction": {"display": "方向", "cell": DirectionCell, "update": False},
"offset": {"display": "开平", "cell": EnumCell, "update": False},
"price": {"display": "价格", "cell": BaseCell, "update": False},
"volume": {"display": "数量", "cell": BaseCell, "update": False},
"time": {"display": "时间", "cell": BaseCell, "update": False},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
@ -510,66 +402,23 @@ class OrderMonitor(BaseMonitor):
"""
Monitor for order data.
"""
event_type = EVENT_ORDER
data_key = "vt_orderid"
sorting = True
headers = {
"orderid": {
"display": "委托号",
"cell": BaseCell,
"update": False
},
"symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"exchange": {
"display": "交易所",
"cell": EnumCell,
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"offset": {
"display": "开平",
"cell": EnumCell,
"update": False
},
"price": {
"display": "价格",
"cell": BaseCell,
"update": False
},
"volume": {
"display": "总数量",
"cell": BaseCell,
"update": True
},
"traded": {
"display": "已成交",
"cell": BaseCell,
"update": True
},
"status": {
"display": "状态",
"cell": EnumCell,
"update": True
},
"time": {
"display": "时间",
"cell": BaseCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"orderid": {"display": "委托号", "cell": BaseCell, "update": False},
"symbol": {"display": "代码", "cell": BaseCell, "update": False},
"exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"direction": {"display": "方向", "cell": DirectionCell, "update": False},
"offset": {"display": "开平", "cell": EnumCell, "update": False},
"price": {"display": "价格", "cell": BaseCell, "update": False},
"volume": {"display": "总数量", "cell": BaseCell, "update": True},
"traded": {"display": "已成交", "cell": BaseCell, "update": True},
"status": {"display": "状态", "cell": EnumCell, "update": True},
"time": {"display": "时间", "cell": BaseCell, "update": True},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
def init_ui(self):
@ -594,51 +443,20 @@ class PositionMonitor(BaseMonitor):
"""
Monitor for position data.
"""
event_type = EVENT_POSITION
data_key = "vt_positionid"
sorting = True
headers = {
"symbol": {
"display": "代码",
"cell": BaseCell,
"update": False
},
"exchange": {
"display": "交易所",
"cell": EnumCell,
"update": False
},
"direction": {
"display": "方向",
"cell": DirectionCell,
"update": False
},
"volume": {
"display": "数量",
"cell": BaseCell,
"update": True
},
"frozen": {
"display": "冻结",
"cell": BaseCell,
"update": True
},
"price": {
"display": "均价",
"cell": BaseCell,
"update": False
},
"pnl": {
"display": "盈亏",
"cell": PnlCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"symbol": {"display": "代码", "cell": BaseCell, "update": False},
"exchange": {"display": "交易所", "cell": EnumCell, "update": False},
"direction": {"display": "方向", "cell": DirectionCell, "update": False},
"volume": {"display": "数量", "cell": BaseCell, "update": True},
"frozen": {"display": "冻结", "cell": BaseCell, "update": True},
"price": {"display": "均价", "cell": BaseCell, "update": False},
"pnl": {"display": "盈亏", "cell": PnlCell, "update": True},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
@ -646,36 +464,17 @@ class AccountMonitor(BaseMonitor):
"""
Monitor for account data.
"""
event_type = EVENT_ACCOUNT
data_key = "vt_accountid"
sorting = True
headers = {
"accountid": {
"display": "账号",
"cell": BaseCell,
"update": False
},
"balance": {
"display": "余额",
"cell": BaseCell,
"update": True
},
"frozen": {
"display": "冻结",
"cell": BaseCell,
"update": True
},
"available": {
"display": "可用",
"cell": BaseCell,
"update": True
},
"gateway_name": {
"display": "接口",
"cell": BaseCell,
"update": False
}
"accountid": {"display": "账号", "cell": BaseCell, "update": False},
"balance": {"display": "余额", "cell": BaseCell, "update": True},
"frozen": {"display": "冻结", "cell": BaseCell, "update": True},
"available": {"display": "可用", "cell": BaseCell, "update": True},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},
}
@ -701,9 +500,7 @@ class ConnectDialog(QtWidgets.QDialog):
self.setWindowTitle(f"连接{self.gateway_name}")
# Default setting provides field name, field data type and field default value.
default_setting = self.main_engine.get_default_setting(
self.gateway_name
)
default_setting = self.main_engine.get_default_setting(self.gateway_name)
# Saved setting provides field data used last time.
loaded_setting = load_setting(self.filename)
@ -732,7 +529,7 @@ class ConnectDialog(QtWidgets.QDialog):
form.addRow(f"{field_name} <{field_type.__name__}>", widget)
self.widgets[field_name] = (widget, field_type)
button = QtWidgets.QPushButton(u"连接")
button = QtWidgets.QPushButton("连接")
button.clicked.connect(self.connect)
form.addRow(button)
@ -792,18 +589,13 @@ class TradingWidget(QtWidgets.QWidget):
self.name_line.setReadOnly(True)
self.direction_combo = QtWidgets.QComboBox()
self.direction_combo.addItems(
[Direction.LONG.value,
Direction.SHORT.value]
)
self.direction_combo.addItems([Direction.LONG.value, Direction.SHORT.value])
self.offset_combo = QtWidgets.QComboBox()
self.offset_combo.addItems([offset.value for offset in Offset])
self.price_type_combo = QtWidgets.QComboBox()
self.price_type_combo.addItems(
[price_type.value for price_type in PriceType]
)
self.price_type_combo.addItems([price_type.value for price_type in PriceType])
double_validator = QtGui.QDoubleValidator()
double_validator.setBottom(0)
@ -846,26 +638,11 @@ class TradingWidget(QtWidgets.QWidget):
self.bp4_label = self.create_label(bid_color)
self.bp5_label = self.create_label(bid_color)
self.bv1_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv2_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv3_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv4_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv5_label = self.create_label(
bid_color,
alignment=QtCore.Qt.AlignRight
)
self.bv1_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.bv2_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.bv3_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.bv4_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.bv5_label = self.create_label(bid_color, alignment=QtCore.Qt.AlignRight)
self.ap1_label = self.create_label(ask_color)
self.ap2_label = self.create_label(ask_color)
@ -873,26 +650,11 @@ class TradingWidget(QtWidgets.QWidget):
self.ap4_label = self.create_label(ask_color)
self.ap5_label = self.create_label(ask_color)
self.av1_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av2_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av3_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av4_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av5_label = self.create_label(
ask_color,
alignment=QtCore.Qt.AlignRight
)
self.av1_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.av2_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.av3_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.av4_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.av5_label = self.create_label(ask_color, alignment=QtCore.Qt.AlignRight)
self.lp_label = self.create_label()
self.return_label = self.create_label(alignment=QtCore.Qt.AlignRight)
@ -916,11 +678,7 @@ class TradingWidget(QtWidgets.QWidget):
vbox.addLayout(form2)
self.setLayout(vbox)
def create_label(
self,
color: str = "",
alignment: int = QtCore.Qt.AlignLeft
):
def create_label(self, color: str = "", alignment: int = QtCore.Qt.AlignLeft):
"""
Create label with certain font color.
"""
@ -992,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget):
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.name_line.setText("")
gateway_name = (self.gateway_combo.currentText())
gateway_name = self.gateway_combo.currentText()
else:
self.name_line.setText(contract.name)
gateway_name = contract.gateway_name
@ -1067,7 +825,7 @@ class TradingWidget(QtWidgets.QWidget):
price_type=PriceType(str(self.price_type_combo.currentText())),
volume=volume,
price=price,
offset=Offset(str(self.offset_combo.currentText()))
offset=Offset(str(self.offset_combo.currentText())),
)
gateway_name = str(self.gateway_combo.currentText())
@ -1118,7 +876,7 @@ class ContractManager(QtWidgets.QWidget):
"product": "合约分类",
"size": "合约乘数",
"pricetick": "价格跳动",
"gateway_name": "交易接口"
"gateway_name": "交易接口",
}
def __init__(self, main_engine, event_engine):
@ -1171,8 +929,7 @@ class ContractManager(QtWidgets.QWidget):
all_contracts = self.main_engine.get_all_contracts()
if flt:
contracts = [
contract for contract in all_contracts
if flt in contract.vt_symbol
contract for contract in all_contracts if flt in contract.vt_symbol
]
else:
contracts = all_contracts

View File

@ -14,14 +14,13 @@ class Singleton(type):
__metaclass__ = Singleton
"""
_instances = {}
def __call__(cls, *args, **kwargs):
""""""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton,
cls).__call__(*args,
**kwargs)
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
@ -39,7 +38,7 @@ def get_temp_path(filename: str):
Get path for temp file with filename.
"""
trader_path = get_trader_path()
temp_path = trader_path.joinpath('.vntrader')
temp_path = trader_path.joinpath(".vntrader")
if not temp_path.exists():
temp_path.mkdir()