[增强功能] PB gateway,xtp加固

This commit is contained in:
msincenselee 2020-06-04 23:09:28 +08:00
parent 246d4927f9
commit 99ae453c67
19 changed files with 1295 additions and 50 deletions

View File

@ -799,12 +799,19 @@ class CtaEngine(BaseEngine):
def get_price(self, vt_symbol: str): def get_price(self, vt_symbol: str):
"""查询合约的最新价格""" """查询合约的最新价格"""
price = self.main_engine.get_price(vt_symbol)
if price:
return price
tick = self.main_engine.get_tick(vt_symbol) tick = self.main_engine.get_tick(vt_symbol)
if tick: if tick:
return tick.last_price return tick.last_price
return None return None
def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol)
def get_account(self, vt_accountid: str = ""): def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金""" """ 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位 # 如果启动风控,则使用风控中的最大仓位
@ -1200,6 +1207,25 @@ class CtaEngine(BaseEngine):
self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex))) self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc()) self.write_error(traceback.format_exc())
def get_strategy_snapshot(self, strategy_name):
"""实时获取策略的K线切片比较耗性能"""
strategy = self.strategies.get(strategy_name, None)
if strategy is None:
return None
try:
# 5.保存策略切片
snapshot = strategy.get_klines_snapshot()
if not snapshot:
self.write_log(f'{strategy_name}返回得K线切片数据为空')
return None
return snapshot
except Exception as ex:
self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc())
return None
def save_strategy_snapshot(self, select_name: str = 'ALL'): def save_strategy_snapshot(self, select_name: str = 'ALL'):
""" """
保存策略K线切片数据 保存策略K线切片数据

View File

@ -601,7 +601,8 @@ class CtaFutureTemplate(CtaTemplate):
def init_policy(self): def init_policy(self):
self.write_log(u'init_policy(),初始化执行逻辑') self.write_log(u'init_policy(),初始化执行逻辑')
self.policy.load() if self.policy:
self.policy.load()
def init_position(self): def init_position(self):
""" """
@ -1356,11 +1357,13 @@ class CtaFutureTemplate(CtaTemplate):
return return
self.write_log(u'{} 当前 {}价格:{}' self.write_log(u'{} 当前 {}价格:{}'
.format(self.cur_datetime, self.vt_symbol, self.cur_price)) .format(self.cur_datetime, self.vt_symbol, self.cur_price))
if hasattr(self, 'policy'): if hasattr(self, 'policy'):
policy = getattr(self, 'policy') policy = getattr(self, 'policy')
op = getattr(policy, 'to_json', None) if policy:
if callable(op): op = getattr(policy, 'to_json', None)
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False)))
def save_dist(self, dist_data): def save_dist(self, dist_data):
""" """

View File

@ -8,12 +8,14 @@ from vnpy.trader.ui.widget import (
TimeCell, TimeCell,
BaseMonitor BaseMonitor
) )
from vnpy.trader.ui.kline.ui_snapshot import UiSnapshot
from ..base import ( from ..base import (
APP_NAME, APP_NAME,
EVENT_CTA_LOG, EVENT_CTA_LOG,
EVENT_CTA_STOPORDER, EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY EVENT_CTA_STRATEGY
) )
from ..engine import CtaEngine from ..engine import CtaEngine
@ -202,11 +204,11 @@ class StrategyManager(QtWidgets.QFrame):
reload_button = QtWidgets.QPushButton("重载") reload_button = QtWidgets.QPushButton("重载")
reload_button.clicked.connect(self.reload_strategy) reload_button.clicked.connect(self.reload_strategy)
save_button = QtWidgets.QPushButton("") save_button = QtWidgets.QPushButton("")
save_button.clicked.connect(self.save_strategy) save_button.clicked.connect(self.save_strategy)
snapshot_button = QtWidgets.QPushButton("切片") view_button = QtWidgets.QPushButton("K线")
snapshot_button.clicked.connect(self.save_snapshot) view_button.clicked.connect(self.view_strategy_snapshot)
strategy_name = self._data["strategy_name"] strategy_name = self._data["strategy_name"]
vt_symbol = self._data["vt_symbol"] vt_symbol = self._data["vt_symbol"]
@ -230,7 +232,7 @@ class StrategyManager(QtWidgets.QFrame):
hbox.addWidget(remove_button) hbox.addWidget(remove_button)
hbox.addWidget(reload_button) hbox.addWidget(reload_button)
hbox.addWidget(save_button) hbox.addWidget(save_button)
hbox.addWidget(snapshot_button) hbox.addWidget(view_button)
vbox = QtWidgets.QVBoxLayout() vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(label) vbox.addWidget(label)
@ -283,13 +285,18 @@ class StrategyManager(QtWidgets.QFrame):
self.cta_engine.reload_strategy(self.strategy_name) self.cta_engine.reload_strategy(self.strategy_name)
def save_strategy(self): def save_strategy(self):
"""保存K线缓存""" """保存策略缓存数据"""
self.cta_engine.save_strategy_data(self.strategy_name) self.cta_engine.save_strategy_data(self.strategy_name)
def save_snapshot(self):
""" 保存切片"""
self.cta_engine.save_strategy_snapshot(self.strategy_name) self.cta_engine.save_strategy_snapshot(self.strategy_name)
def view_strategy_snapshot(self):
"""实时查看策略切片"""
snapshot = self.cta_engine.get_strategy_snapshot(self.strategy_name)
if snapshot is None:
return
ui_snapshot = UiSnapshot()
ui_snapshot.show(snapshot_file="", d=snapshot)
class DataMonitor(QtWidgets.QTableWidget): class DataMonitor(QtWidgets.QTableWidget):
""" """
Table monitor for parameters and variables. Table monitor for parameters and variables.

View File

@ -13,7 +13,7 @@
cythonize -i demo_strategy.py cythonize -i demo_strategy.py
编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd 编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
改名=> demo_strategy.pyd 改名=> demo_strategy.pyd

View File

@ -1108,14 +1108,23 @@ class CtaEngine(BaseEngine):
if len(setting) == 0: if len(setting) == 0:
strategies_setting = load_json(self.setting_filename) strategies_setting = load_json(self.setting_filename)
old_strategy_config = strategies_setting.get(strategy_name, {}) old_strategy_config = strategies_setting.get(strategy_name, {})
self.write_log(f'使用配置文件的配置:{old_strategy_config}')
else: else:
old_strategy_config = copy(self.strategy_setting[strategy_name]) old_strategy_config = copy(self.strategy_setting[strategy_name])
self.write_log(f'使用已经运行的配置:{old_strategy_config}')
class_name = old_strategy_config.get('class_name') class_name = old_strategy_config.get('class_name')
self.write_log(f'使用策略类名:{class_name}')
# 没有配置vt_symbol时使用配置文件/旧配置中的vt_symbol
if len(vt_symbol) == 0: if len(vt_symbol) == 0:
vt_symbol = old_strategy_config.get('vt_symbol') vt_symbol = old_strategy_config.get('vt_symbol')
self.write_log(f'使用配置文件/已运行配置的vt_symbol:{vt_symbol}')
# 没有新配置时,使用配置文件/旧配置中的setting
if len(setting) == 0: if len(setting) == 0:
setting = old_strategy_config.get('setting') setting = old_strategy_config.get('setting')
self.write_log(f'没有新策略参数,使用配置文件/旧配置中的setting:{setting}')
module_name = self.class_module_map[class_name] module_name = self.class_module_map[class_name]
# 重新load class module # 重新load class module
@ -1639,15 +1648,17 @@ class CtaEngine(BaseEngine):
compare_info = '' compare_info = ''
for vt_symbol in sorted(vt_symbols): for vt_symbol in sorted(vt_symbols):
# 发送不一致得结果 # 发送不一致得结果
symbol_pos = compare_pos.pop(vt_symbol) symbol_pos = compare_pos.pop(vt_symbol, None)
if symbol_pos is None:
continue
d_long = { d_long = {
'account_id': self.engine_config.get('account_id', '-'), 'account_id': self.engine_config.get('accountid', '-'),
'vt_symbol': vt_symbol, 'vt_symbol': vt_symbol,
'direction': Direction.LONG.value, 'direction': Direction.LONG.value,
'strategy_list': symbol_pos.get('多单策略', [])} 'strategy_list': symbol_pos.get('多单策略', [])}
d_short = { d_short = {
'account_id': self.engine_config.get('account_id', '-'), 'account_id': self.engine_config.get('accountid', '-'),
'vt_symbol': vt_symbol, 'vt_symbol': vt_symbol,
'direction': Direction.SHORT.value, 'direction': Direction.SHORT.value,
'strategy_list': symbol_pos.get('空单策略', [])} 'strategy_list': symbol_pos.get('空单策略', [])}
@ -1693,7 +1704,7 @@ class CtaEngine(BaseEngine):
# 不匹配输入到stdErr通道 # 不匹配输入到stdErr通道
if pos_compare_result != '': if pos_compare_result != '':
msg = u'账户{}持仓不匹配: {}' \ msg = u'账户{}持仓不匹配: {}' \
.format(self.engine_config.get('account_id', '-'), .format(self.engine_config.get('accountid', '-'),
pos_compare_result) pos_compare_result)
try: try:
from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.util_wechat import send_wx_msg

View File

@ -849,6 +849,7 @@ class CtaProTemplate(CtaTemplate):
for vt_orderid in list(self.active_orders.keys()): for vt_orderid in list(self.active_orders.keys()):
order_info = self.active_orders.get(vt_orderid) order_info = self.active_orders.get(vt_orderid)
order_grid = order_info.get('grid',None)
if order_info.get('status', None) in [Status.CANCELLED, Status.REJECTED]: if order_info.get('status', None) in [Status.CANCELLED, Status.REJECTED]:
self.active_orders.pop(vt_orderid, None) self.active_orders.pop(vt_orderid, None)
continue continue
@ -863,6 +864,11 @@ class CtaProTemplate(CtaTemplate):
order_info.update({'status': Status.CANCELLING}) order_info.update({'status': Status.CANCELLING})
else: else:
order_info.update({'status': Status.CANCELLED}) order_info.update({'status': Status.CANCELLED})
if order_grid:
if vt_orderid in order_grid.order_ids:
order_grid.order_ids.remove(vt_orderid)
if len(order_grid.order_ids) == 0:
order_grid.order_status = False
if len(self.active_orders) < 1: if len(self.active_orders) < 1:
self.entrust = 0 self.entrust = 0

View File

@ -645,8 +645,8 @@ class CtaGridTrade(CtaComponent):
if lots > 0: if lots > 0:
for i in range(0, lots, 1): for i in range(0, lots, 1):
# 做多,开仓价为下阻力线-网格高度*i平仓价为开仓价+止盈高度,开仓数量为缺省 # 做多,开仓价为下阻力线-网格高度*i平仓价为开仓价+止盈高度,开仓数量为缺省
open_price = int((down_line - self.grid_height * down_rate) / self.price_tick) * self.price_tick open_price = int((down_line - self.grid_height * down_rate * i) / self.price_tick) * self.price_tick
close_price = int((open_price + self.grid_win * down_rate) / self.price_tick) * self.price_tick close_price = int((open_price + self.grid_win * down_rate * i) / self.price_tick) * self.price_tick
grid = CtaGrid(direction=Direction.LONG, grid = CtaGrid(direction=Direction.LONG,
open_price=open_price, open_price=open_price,
@ -686,8 +686,8 @@ class CtaGridTrade(CtaComponent):
if lots > 0: if lots > 0:
# 做空,开仓价为上阻力线+网格高度*i平仓价为开仓价-止盈高度,开仓数量为缺省 # 做空,开仓价为上阻力线+网格高度*i平仓价为开仓价-止盈高度,开仓数量为缺省
for i in range(0, lots, 1): for i in range(0, lots, 1):
open_price = int((upper_line + self.grid_height * upper_rate) / self.price_tick) * self.price_tick open_price = int((upper_line + self.grid_height * upper_rate * i) / self.price_tick) * self.price_tick
close_price = int((open_price - self.grid_win * upper_rate) / self.price_tick) * self.price_tick close_price = int((open_price - self.grid_win * upper_rate * i) / self.price_tick) * self.price_tick
grid = CtaGrid(direction=Direction.SHORT, grid = CtaGrid(direction=Direction.SHORT,
open_price=open_price, open_price=open_price,

View File

@ -2935,7 +2935,7 @@ class CtaLineBar(object):
if self.bar_len < maxLen: if self.bar_len < maxLen:
return return
dif, dea, macd = ta.MACD(np.append(self.close_array[-maxLen:], [self.line_bar[-1].close]), dif, dea, macd = ta.MACD(np.append(self.close_array[-maxLen:], [self.line_bar[-1].close_price]),
fastperiod=self.para_macd_fast_len, fastperiod=self.para_macd_fast_len,
slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len) slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len)
@ -4008,7 +4008,7 @@ class CtaLineBar(object):
return return
# 3、获取前InputN周期(包含当前周期的K线 # 3、获取前InputN周期(包含当前周期的K线
last_bar_mid3 = (self.line_bar[-1].close_price + self.line_bar[-1].high_price + self.line_bar[-1].low_price) / 3 last_bar_mid3 = (self.line_bar[-1].close_price + self.line_bar[-1].high_price + self.line_bar[-1].low_price) / 3
bar_mid3_ema10 = ta.EMA(np.append(self.mid3_array[-ema_len * 3:], [last_bar_mid3]), ema_len)[-1] bar_mid3_ema10 = ta.EMA(np.append(self.mid3_array[-ema_len * 4:], [last_bar_mid3]), ema_len)[-1]
self._rt_yb = round(float(bar_mid3_ema10), self.round_n) self._rt_yb = round(float(bar_mid3_ema10), self.round_n)
@property @property

View File

@ -262,7 +262,6 @@ class TdxFutureData(object):
self.symbol_exchange_dict.update({tdx_symbol: Tdx_Vn_Exchange_Map.get(str(tdx_market_id))}) self.symbol_exchange_dict.update({tdx_symbol: Tdx_Vn_Exchange_Map.get(str(tdx_market_id))})
self.symbol_market_dict.update({tdx_symbol: tdx_market_id}) self.symbol_market_dict.update({tdx_symbol: tdx_market_id})
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
def get_bars(self, def get_bars(self,
symbol: str, symbol: str,

View File

@ -68,7 +68,8 @@ from vnpy.trader.utility import (
get_trading_date, get_trading_date,
get_underlying_symbol, get_underlying_symbol,
round_to, round_to,
BarGenerator BarGenerator,
print_dict
) )
from vnpy.trader.event import EVENT_TIMER from vnpy.trader.event import EVENT_TIMER
@ -222,6 +223,12 @@ class CtpGateway(BaseGateway):
self.combiner_conf_dict = c.get_config() self.combiner_conf_dict = c.get_config()
if len(self.combiner_conf_dict) > 0: if len(self.combiner_conf_dict) > 0:
self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict)) self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict))
contract_dict = c.get_contracts()
for vt_symbol, contract in contract_dict.items():
contract.gateway_name = self.gateway_name
self.on_contract(contract)
except Exception as ex: # noqa except Exception as ex: # noqa
pass pass
if not self.td_api: if not self.td_api:

View File

@ -53,19 +53,21 @@ from vnpy.trader.constant import (
) )
from vnpy.trader.utility import get_file_path from vnpy.trader.utility import get_file_path
# 委托方式映射
ORDERTYPE_VT2IB = { ORDERTYPE_VT2IB = {
OrderType.LIMIT: "LMT", OrderType.LIMIT: "LMT", # 限价单
OrderType.MARKET: "MKT", OrderType.MARKET: "MKT", # 市场价
OrderType.STOP: "STP" OrderType.STOP: "STP" # 停止价
} }
ORDERTYPE_IB2VT = {v: k for k, v in ORDERTYPE_VT2IB.items()} ORDERTYPE_IB2VT = {v: k for k, v in ORDERTYPE_VT2IB.items()}
# 买卖方向映射
DIRECTION_VT2IB = {Direction.LONG: "BUY", Direction.SHORT: "SELL"} DIRECTION_VT2IB = {Direction.LONG: "BUY", Direction.SHORT: "SELL"}
DIRECTION_IB2VT = {v: k for k, v in DIRECTION_VT2IB.items()} DIRECTION_IB2VT = {v: k for k, v in DIRECTION_VT2IB.items()}
DIRECTION_IB2VT["BOT"] = Direction.LONG DIRECTION_IB2VT["BOT"] = Direction.LONG
DIRECTION_IB2VT["SLD"] = Direction.SHORT DIRECTION_IB2VT["SLD"] = Direction.SHORT
# 交易所映射
EXCHANGE_VT2IB = { EXCHANGE_VT2IB = {
Exchange.SMART: "SMART", Exchange.SMART: "SMART",
Exchange.NYMEX: "NYMEX", Exchange.NYMEX: "NYMEX",
@ -75,10 +77,13 @@ EXCHANGE_VT2IB = {
Exchange.ICE: "ICE", Exchange.ICE: "ICE",
Exchange.SEHK: "SEHK", Exchange.SEHK: "SEHK",
Exchange.HKFE: "HKFE", Exchange.HKFE: "HKFE",
Exchange.CFE: "CFE" Exchange.CFE: "CFE",
Exchange.NYSE: "NYSE",
Exchange.NASDAQ: "NASDAQ"
} }
EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()} EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()}
# 状态映射
STATUS_IB2VT = { STATUS_IB2VT = {
"ApiPending": Status.SUBMITTING, "ApiPending": Status.SUBMITTING,
"PendingSubmit": Status.SUBMITTING, "PendingSubmit": Status.SUBMITTING,
@ -90,6 +95,7 @@ STATUS_IB2VT = {
"Inactive": Status.REJECTED, "Inactive": Status.REJECTED,
} }
# 品种类型映射
PRODUCT_IB2VT = { PRODUCT_IB2VT = {
"STK": Product.EQUITY, "STK": Product.EQUITY,
"CASH": Product.FOREX, "CASH": Product.FOREX,
@ -99,14 +105,17 @@ PRODUCT_IB2VT = {
"FOT": Product.OPTION "FOT": Product.OPTION
} }
# 期权映射
OPTION_VT2IB = {OptionType.CALL: "CALL", OptionType.PUT: "PUT"} OPTION_VT2IB = {OptionType.CALL: "CALL", OptionType.PUT: "PUT"}
# 币种映射
CURRENCY_VT2IB = { CURRENCY_VT2IB = {
Currency.USD: "USD", Currency.USD: "USD",
Currency.CNY: "CNY", Currency.CNY: "CNY",
Currency.HKD: "HKD", Currency.HKD: "HKD",
} }
# tick字段映射
TICKFIELD_IB2VT = { TICKFIELD_IB2VT = {
0: "bid_volume_1", 0: "bid_volume_1",
1: "bid_price_1", 1: "bid_price_1",
@ -121,6 +130,7 @@ TICKFIELD_IB2VT = {
14: "open_price", 14: "open_price",
} }
# 账号字段映射
ACCOUNTFIELD_IB2VT = { ACCOUNTFIELD_IB2VT = {
"NetLiquidationByCurrency": "balance", "NetLiquidationByCurrency": "balance",
"NetLiquidation": "balance", "NetLiquidation": "balance",
@ -129,12 +139,14 @@ ACCOUNTFIELD_IB2VT = {
"MaintMarginReq": "margin", "MaintMarginReq": "margin",
} }
# 时间周期映射
INTERVAL_VT2IB = { INTERVAL_VT2IB = {
Interval.MINUTE: "1 min", Interval.MINUTE: "1 min",
Interval.HOUR: "1 hour", Interval.HOUR: "1 hour",
Interval.DAILY: "1 day", Interval.DAILY: "1 day",
} }
# 合约连接符
JOIN_SYMBOL = "-" JOIN_SYMBOL = "-"
@ -150,9 +162,9 @@ class IbGateway(BaseGateway):
exchanges = list(EXCHANGE_VT2IB.keys()) exchanges = list(EXCHANGE_VT2IB.keys())
def __init__(self, event_engine): def __init__(self, event_engine, gateway_name='IB'):
"""""" """"""
super().__init__(event_engine, "IB") super().__init__(event_engine, gateway_name)
self.api = IbApi(self) self.api = IbApi(self)
@ -483,13 +495,14 @@ class IbApi(EWrapper):
self.gateway.write_log(msg) self.gateway.write_log(msg)
return return
ib_size = contract.multiplier try:
if not ib_size: ib_size = int(contract.multiplier)
except ValueError:
ib_size = 1 ib_size = 1
price = averageCost / ib_size price = averageCost / ib_size
pos = PositionData( pos = PositionData(
symbol=contract.conId, symbol=generate_symbol(contract),
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
volume=position, volume=position,
@ -807,7 +820,7 @@ class IbClient(EClient):
def generate_ib_contract(symbol: str, exchange: Exchange) -> Optional[Contract]: def generate_ib_contract(symbol: str, exchange: Exchange) -> Optional[Contract]:
"""""" """生成ib合约"""
try: try:
fields = symbol.split(JOIN_SYMBOL) fields = symbol.split(JOIN_SYMBOL)
@ -831,7 +844,7 @@ def generate_ib_contract(symbol: str, exchange: Exchange) -> Optional[Contract]:
def generate_symbol(ib_contract: Contract) -> str: def generate_symbol(ib_contract: Contract) -> str:
"""""" """生成合约代码"""
fields = [ib_contract.symbol] fields = [ib_contract.symbol]
if ib_contract.secType in ["FUT", "OPT", "FOP"]: if ib_contract.secType in ["FUT", "OPT", "FOP"]:

File diff suppressed because it is too large Load Diff

View File

@ -67,7 +67,8 @@ from vnpy.trader.utility import (
get_trading_date, get_trading_date,
get_underlying_symbol, get_underlying_symbol,
round_to, round_to,
BarGenerator BarGenerator,
print_dict
) )
from vnpy.trader.event import EVENT_TIMER from vnpy.trader.event import EVENT_TIMER
@ -214,6 +215,12 @@ class RohonGateway(BaseGateway):
self.combiner_conf_dict = c.get_config() self.combiner_conf_dict = c.get_config()
if len(self.combiner_conf_dict) > 0: if len(self.combiner_conf_dict) > 0:
self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict)) self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict))
contract_dict = c.get_contracts()
for vt_symbol, contract in contract_dict.items():
contract.gateway_name = self.gateway_name
self.on_contract(contract)
except Exception as ex: # noqa except Exception as ex: # noqa
pass pass

View File

@ -200,7 +200,7 @@ class XtpGateway(BaseGateway):
def process_timer_event(self, event) -> None: def process_timer_event(self, event) -> None:
"""""" """"""
self.count += 1 self.count += 1
if self.count < 2: if self.count < 5:
return return
self.count = 0 self.count = 0
@ -371,6 +371,8 @@ class XtpMdApi(MdApi):
min_volume=data["buy_qty_unit"], min_volume=data["buy_qty_unit"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
#if contract.symbol.startswith('1230'):
# self.gateway.write_log(msg=f'合约信息:{contract.__dict__}')
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
# 更新最新价 # 更新最新价

View File

@ -43,6 +43,7 @@ class Status(Enum):
CANCELLED = "已撤销" CANCELLED = "已撤销"
CANCELLING = "撤销中" CANCELLING = "撤销中"
REJECTED = "拒单" REJECTED = "拒单"
UNKNOWN = "未知"
class Product(Enum): class Product(Enum):
@ -95,9 +96,12 @@ class Exchange(Enum):
SZSE = "SZSE" # Shenzhen Stock Exchange SZSE = "SZSE" # Shenzhen Stock Exchange
SGE = "SGE" # Shanghai Gold Exchange SGE = "SGE" # Shanghai Gold Exchange
WXE = "WXE" # Wuxi Steel Exchange WXE = "WXE" # Wuxi Steel Exchange
CFETS = "CFETS" # China Foreign Exchange Trade System
# Global # Global
SMART = "SMART" # Smart Router for US stocks SMART = "SMART" # Smart Router for US stocks
NYSE = "NYSE" # New York Stock Exchnage
NASDAQ = "NASDAQ" # Nasdaq Exchange
NYMEX = "NYMEX" # New York Mercantile Exchange NYMEX = "NYMEX" # New York Mercantile Exchange
COMEX = "COMEX" # a division of theNew York Mercantile Exchange COMEX = "COMEX" # a division of theNew York Mercantile Exchange
GLOBEX = "GLOBEX" # Globex of CME GLOBEX = "GLOBEX" # Globex of CME

View File

@ -57,10 +57,11 @@ class OffsetConverter:
def get_position_holding(self, vt_symbol: str, gateway_name: str = '') -> "PositionHolding": def get_position_holding(self, vt_symbol: str, gateway_name: str = '') -> "PositionHolding":
"""获取持仓信息""" """获取持仓信息"""
if len(gateway_name) == 0: if gateway_name is None or len(gateway_name) == 0:
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract: if contract:
gateway_name = contract.gateway_name gateway_name = contract.gateway_name
k = f'{gateway_name}.{vt_symbol}' k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None) holding = self.holdings.get(k, None)
if not holding: if not holding:

View File

@ -331,13 +331,15 @@ class LocalOrderManager:
Management tool to support use local order id for trading. Management tool to support use local order id for trading.
""" """
def __init__(self, gateway: BaseGateway, order_prefix: str = ""): def __init__(self, gateway: BaseGateway, order_prefix: str = "", order_rjust:int = 8):
"""""" """"""
self.gateway: BaseGateway = gateway self.gateway: BaseGateway = gateway
# For generating local orderid # For generating local orderid
self.order_prefix: str = order_prefix self.order_prefix: str = order_prefix
self.order_rjust: int = order_rjust
self.order_count: int = 0 self.order_count: int = 0
self.orders: Dict[str, OrderData] = {} # local_orderid: order self.orders: Dict[str, OrderData] = {} # local_orderid: order
# Map between local and system orderid # Map between local and system orderid
@ -362,7 +364,7 @@ class LocalOrderManager:
Generate a new local orderid. Generate a new local orderid.
""" """
self.order_count += 1 self.order_count += 1
local_orderid = self.order_prefix + str(self.order_count).rjust(8, "0") local_orderid = self.order_prefix + str(self.order_count).rjust(self.order_rjust, "0")
return local_orderid return local_orderid
def get_local_orderid(self, sys_orderid: str) -> str: def get_local_orderid(self, sys_orderid: str) -> str:
@ -421,8 +423,11 @@ class LocalOrderManager:
def get_order_with_local_orderid(self, local_orderid: str) -> OrderData: def get_order_with_local_orderid(self, local_orderid: str) -> OrderData:
"""""" """"""
order = self.orders[local_orderid] order = self.orders.get(local_orderid, None)
return copy(order) if order:
return copy(order)
else:
return None
def on_order(self, order: OrderData) -> None: def on_order(self, order: OrderData) -> None:
""" """

View File

@ -1041,7 +1041,7 @@ class ContractManager(QtWidgets.QWidget):
all_contracts = self.main_engine.get_all_contracts() all_contracts = self.main_engine.get_all_contracts()
if flt: if flt:
contracts = [ contracts = [
contract for contract in all_contracts if flt in contract.vt_symbol contract for contract in all_contracts if flt in contract.vt_symbol.lower()
] ]
else: else:
contracts = all_contracts contracts = all_contracts

View File

@ -344,7 +344,7 @@ def get_csv_last_dt(file_name, dt_index=0, dt_format='%Y-%m-%d %H:%M:%S', line_l
return None return None
return None return None
def append_data(file_name: str, dict_data: dict, field_names: list = []): def append_data(file_name: str, dict_data: dict, field_names: list = [], auto_header=True, encoding='utf8'):
""" """
添加数据到csv文件中 添加数据到csv文件中
:param file_name: csv的文件全路径 :param file_name: csv的文件全路径
@ -354,15 +354,16 @@ def append_data(file_name: str, dict_data: dict, field_names: list = []):
dict_fieldnames = sorted(list(dict_data.keys())) if len(field_names) == 0 else field_names dict_fieldnames = sorted(list(dict_data.keys())) if len(field_names) == 0 else field_names
try: try:
if not os.path.exists(file_name): if not os.path.exists(file_name): # or os.path.getsize(file_name) == 0:
print(u'create csv file:{}'.format(file_name)) print(u'create csv file:{}'.format(file_name))
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile: with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel') writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel')
print(u'write csv header:{}'.format(dict_fieldnames)) if auto_header:
writer.writeheader() print(u'write csv header:{}'.format(dict_fieldnames))
writer.writeheader()
writer.writerow(dict_data) writer.writerow(dict_data)
else: else:
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile: with open(file_name, 'a', encoding=encoding, newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel', writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel',
extrasaction='ignore') extrasaction='ignore')
writer.writerow(dict_data) writer.writerow(dict_data)