[增强功能] PB gateway,xtp加固

This commit is contained in:
msincenselee 2020-06-04 23:09:28 +08:00
parent 246d4927f9
commit 99ae453c67
19 changed files with 1295 additions and 50 deletions

View File

@ -799,12 +799,19 @@ class CtaEngine(BaseEngine):
def get_price(self, vt_symbol: str):
"""查询合约的最新价格"""
price = self.main_engine.get_price(vt_symbol)
if price:
return price
tick = self.main_engine.get_tick(vt_symbol)
if tick:
return tick.last_price
return None
def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol)
def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位
@ -1200,6 +1207,25 @@ class CtaEngine(BaseEngine):
self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc())
def get_strategy_snapshot(self, strategy_name):
"""实时获取策略的K线切片比较耗性能"""
strategy = self.strategies.get(strategy_name, None)
if strategy is None:
return None
try:
# 5.保存策略切片
snapshot = strategy.get_klines_snapshot()
if not snapshot:
self.write_log(f'{strategy_name}返回得K线切片数据为空')
return None
return snapshot
except Exception as ex:
self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc())
return None
def save_strategy_snapshot(self, select_name: str = 'ALL'):
"""
保存策略K线切片数据

View File

@ -601,7 +601,8 @@ class CtaFutureTemplate(CtaTemplate):
def init_policy(self):
self.write_log(u'init_policy(),初始化执行逻辑')
self.policy.load()
if self.policy:
self.policy.load()
def init_position(self):
"""
@ -1356,11 +1357,13 @@ class CtaFutureTemplate(CtaTemplate):
return
self.write_log(u'{} 当前 {}价格:{}'
.format(self.cur_datetime, self.vt_symbol, self.cur_price))
if hasattr(self, 'policy'):
policy = getattr(self, 'policy')
op = getattr(policy, 'to_json', None)
if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False)))
if policy:
op = getattr(policy, 'to_json', None)
if callable(op):
self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False)))
def save_dist(self, dist_data):
"""

View File

@ -8,12 +8,14 @@ from vnpy.trader.ui.widget import (
TimeCell,
BaseMonitor
)
from vnpy.trader.ui.kline.ui_snapshot import UiSnapshot
from ..base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY
)
from ..engine import CtaEngine
@ -202,11 +204,11 @@ class StrategyManager(QtWidgets.QFrame):
reload_button = QtWidgets.QPushButton("重载")
reload_button.clicked.connect(self.reload_strategy)
save_button = QtWidgets.QPushButton("")
save_button = QtWidgets.QPushButton("")
save_button.clicked.connect(self.save_strategy)
snapshot_button = QtWidgets.QPushButton("切片")
snapshot_button.clicked.connect(self.save_snapshot)
view_button = QtWidgets.QPushButton("K线")
view_button.clicked.connect(self.view_strategy_snapshot)
strategy_name = self._data["strategy_name"]
vt_symbol = self._data["vt_symbol"]
@ -230,7 +232,7 @@ class StrategyManager(QtWidgets.QFrame):
hbox.addWidget(remove_button)
hbox.addWidget(reload_button)
hbox.addWidget(save_button)
hbox.addWidget(snapshot_button)
hbox.addWidget(view_button)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(label)
@ -283,13 +285,18 @@ class StrategyManager(QtWidgets.QFrame):
self.cta_engine.reload_strategy(self.strategy_name)
def save_strategy(self):
"""保存K线缓存"""
"""保存策略缓存数据"""
self.cta_engine.save_strategy_data(self.strategy_name)
def save_snapshot(self):
""" 保存切片"""
self.cta_engine.save_strategy_snapshot(self.strategy_name)
def view_strategy_snapshot(self):
"""实时查看策略切片"""
snapshot = self.cta_engine.get_strategy_snapshot(self.strategy_name)
if snapshot is None:
return
ui_snapshot = UiSnapshot()
ui_snapshot.show(snapshot_file="", d=snapshot)
class DataMonitor(QtWidgets.QTableWidget):
"""
Table monitor for parameters and variables.

View File

@ -13,7 +13,7 @@
cythonize -i demo_strategy.py
编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
编译完成后Demo文件夹下会多出2个新的文件其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd
改名=> demo_strategy.pyd

View File

@ -1108,14 +1108,23 @@ class CtaEngine(BaseEngine):
if len(setting) == 0:
strategies_setting = load_json(self.setting_filename)
old_strategy_config = strategies_setting.get(strategy_name, {})
self.write_log(f'使用配置文件的配置:{old_strategy_config}')
else:
old_strategy_config = copy(self.strategy_setting[strategy_name])
self.write_log(f'使用已经运行的配置:{old_strategy_config}')
class_name = old_strategy_config.get('class_name')
self.write_log(f'使用策略类名:{class_name}')
# 没有配置vt_symbol时使用配置文件/旧配置中的vt_symbol
if len(vt_symbol) == 0:
vt_symbol = old_strategy_config.get('vt_symbol')
self.write_log(f'使用配置文件/已运行配置的vt_symbol:{vt_symbol}')
# 没有新配置时,使用配置文件/旧配置中的setting
if len(setting) == 0:
setting = old_strategy_config.get('setting')
self.write_log(f'没有新策略参数,使用配置文件/旧配置中的setting:{setting}')
module_name = self.class_module_map[class_name]
# 重新load class module
@ -1639,15 +1648,17 @@ class CtaEngine(BaseEngine):
compare_info = ''
for vt_symbol in sorted(vt_symbols):
# 发送不一致得结果
symbol_pos = compare_pos.pop(vt_symbol)
symbol_pos = compare_pos.pop(vt_symbol, None)
if symbol_pos is None:
continue
d_long = {
'account_id': self.engine_config.get('account_id', '-'),
'account_id': self.engine_config.get('accountid', '-'),
'vt_symbol': vt_symbol,
'direction': Direction.LONG.value,
'strategy_list': symbol_pos.get('多单策略', [])}
d_short = {
'account_id': self.engine_config.get('account_id', '-'),
'account_id': self.engine_config.get('accountid', '-'),
'vt_symbol': vt_symbol,
'direction': Direction.SHORT.value,
'strategy_list': symbol_pos.get('空单策略', [])}
@ -1693,7 +1704,7 @@ class CtaEngine(BaseEngine):
# 不匹配输入到stdErr通道
if pos_compare_result != '':
msg = u'账户{}持仓不匹配: {}' \
.format(self.engine_config.get('account_id', '-'),
.format(self.engine_config.get('accountid', '-'),
pos_compare_result)
try:
from vnpy.trader.util_wechat import send_wx_msg

View File

@ -849,6 +849,7 @@ class CtaProTemplate(CtaTemplate):
for vt_orderid in list(self.active_orders.keys()):
order_info = self.active_orders.get(vt_orderid)
order_grid = order_info.get('grid',None)
if order_info.get('status', None) in [Status.CANCELLED, Status.REJECTED]:
self.active_orders.pop(vt_orderid, None)
continue
@ -863,6 +864,11 @@ class CtaProTemplate(CtaTemplate):
order_info.update({'status': Status.CANCELLING})
else:
order_info.update({'status': Status.CANCELLED})
if order_grid:
if vt_orderid in order_grid.order_ids:
order_grid.order_ids.remove(vt_orderid)
if len(order_grid.order_ids) == 0:
order_grid.order_status = False
if len(self.active_orders) < 1:
self.entrust = 0

View File

@ -645,8 +645,8 @@ class CtaGridTrade(CtaComponent):
if lots > 0:
for i in range(0, lots, 1):
# 做多,开仓价为下阻力线-网格高度*i平仓价为开仓价+止盈高度,开仓数量为缺省
open_price = int((down_line - self.grid_height * down_rate) / self.price_tick) * self.price_tick
close_price = int((open_price + self.grid_win * down_rate) / self.price_tick) * self.price_tick
open_price = int((down_line - self.grid_height * down_rate * i) / self.price_tick) * self.price_tick
close_price = int((open_price + self.grid_win * down_rate * i) / self.price_tick) * self.price_tick
grid = CtaGrid(direction=Direction.LONG,
open_price=open_price,
@ -686,8 +686,8 @@ class CtaGridTrade(CtaComponent):
if lots > 0:
# 做空,开仓价为上阻力线+网格高度*i平仓价为开仓价-止盈高度,开仓数量为缺省
for i in range(0, lots, 1):
open_price = int((upper_line + self.grid_height * upper_rate) / self.price_tick) * self.price_tick
close_price = int((open_price - self.grid_win * upper_rate) / self.price_tick) * self.price_tick
open_price = int((upper_line + self.grid_height * upper_rate * i) / self.price_tick) * self.price_tick
close_price = int((open_price - self.grid_win * upper_rate * i) / self.price_tick) * self.price_tick
grid = CtaGrid(direction=Direction.SHORT,
open_price=open_price,

View File

@ -2935,7 +2935,7 @@ class CtaLineBar(object):
if self.bar_len < maxLen:
return
dif, dea, macd = ta.MACD(np.append(self.close_array[-maxLen:], [self.line_bar[-1].close]),
dif, dea, macd = ta.MACD(np.append(self.close_array[-maxLen:], [self.line_bar[-1].close_price]),
fastperiod=self.para_macd_fast_len,
slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len)
@ -4008,7 +4008,7 @@ class CtaLineBar(object):
return
# 3、获取前InputN周期(包含当前周期的K线
last_bar_mid3 = (self.line_bar[-1].close_price + self.line_bar[-1].high_price + self.line_bar[-1].low_price) / 3
bar_mid3_ema10 = ta.EMA(np.append(self.mid3_array[-ema_len * 3:], [last_bar_mid3]), ema_len)[-1]
bar_mid3_ema10 = ta.EMA(np.append(self.mid3_array[-ema_len * 4:], [last_bar_mid3]), ema_len)[-1]
self._rt_yb = round(float(bar_mid3_ema10), self.round_n)
@property

View File

@ -262,7 +262,6 @@ class TdxFutureData(object):
self.symbol_exchange_dict.update({tdx_symbol: Tdx_Vn_Exchange_Map.get(str(tdx_market_id))})
self.symbol_market_dict.update({tdx_symbol: tdx_market_id})
# ----------------------------------------------------------------------
def get_bars(self,
symbol: str,

View File

@ -68,7 +68,8 @@ from vnpy.trader.utility import (
get_trading_date,
get_underlying_symbol,
round_to,
BarGenerator
BarGenerator,
print_dict
)
from vnpy.trader.event import EVENT_TIMER
@ -222,6 +223,12 @@ class CtpGateway(BaseGateway):
self.combiner_conf_dict = c.get_config()
if len(self.combiner_conf_dict) > 0:
self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict))
contract_dict = c.get_contracts()
for vt_symbol, contract in contract_dict.items():
contract.gateway_name = self.gateway_name
self.on_contract(contract)
except Exception as ex: # noqa
pass
if not self.td_api:

View File

@ -53,19 +53,21 @@ from vnpy.trader.constant import (
)
from vnpy.trader.utility import get_file_path
# 委托方式映射
ORDERTYPE_VT2IB = {
OrderType.LIMIT: "LMT",
OrderType.MARKET: "MKT",
OrderType.STOP: "STP"
OrderType.LIMIT: "LMT", # 限价单
OrderType.MARKET: "MKT", # 市场价
OrderType.STOP: "STP" # 停止价
}
ORDERTYPE_IB2VT = {v: k for k, v in ORDERTYPE_VT2IB.items()}
# 买卖方向映射
DIRECTION_VT2IB = {Direction.LONG: "BUY", Direction.SHORT: "SELL"}
DIRECTION_IB2VT = {v: k for k, v in DIRECTION_VT2IB.items()}
DIRECTION_IB2VT["BOT"] = Direction.LONG
DIRECTION_IB2VT["SLD"] = Direction.SHORT
# 交易所映射
EXCHANGE_VT2IB = {
Exchange.SMART: "SMART",
Exchange.NYMEX: "NYMEX",
@ -75,10 +77,13 @@ EXCHANGE_VT2IB = {
Exchange.ICE: "ICE",
Exchange.SEHK: "SEHK",
Exchange.HKFE: "HKFE",
Exchange.CFE: "CFE"
Exchange.CFE: "CFE",
Exchange.NYSE: "NYSE",
Exchange.NASDAQ: "NASDAQ"
}
EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()}
# 状态映射
STATUS_IB2VT = {
"ApiPending": Status.SUBMITTING,
"PendingSubmit": Status.SUBMITTING,
@ -90,6 +95,7 @@ STATUS_IB2VT = {
"Inactive": Status.REJECTED,
}
# 品种类型映射
PRODUCT_IB2VT = {
"STK": Product.EQUITY,
"CASH": Product.FOREX,
@ -99,14 +105,17 @@ PRODUCT_IB2VT = {
"FOT": Product.OPTION
}
# 期权映射
OPTION_VT2IB = {OptionType.CALL: "CALL", OptionType.PUT: "PUT"}
# 币种映射
CURRENCY_VT2IB = {
Currency.USD: "USD",
Currency.CNY: "CNY",
Currency.HKD: "HKD",
}
# tick字段映射
TICKFIELD_IB2VT = {
0: "bid_volume_1",
1: "bid_price_1",
@ -121,6 +130,7 @@ TICKFIELD_IB2VT = {
14: "open_price",
}
# 账号字段映射
ACCOUNTFIELD_IB2VT = {
"NetLiquidationByCurrency": "balance",
"NetLiquidation": "balance",
@ -129,12 +139,14 @@ ACCOUNTFIELD_IB2VT = {
"MaintMarginReq": "margin",
}
# 时间周期映射
INTERVAL_VT2IB = {
Interval.MINUTE: "1 min",
Interval.HOUR: "1 hour",
Interval.DAILY: "1 day",
}
# 合约连接符
JOIN_SYMBOL = "-"
@ -150,9 +162,9 @@ class IbGateway(BaseGateway):
exchanges = list(EXCHANGE_VT2IB.keys())
def __init__(self, event_engine):
def __init__(self, event_engine, gateway_name='IB'):
""""""
super().__init__(event_engine, "IB")
super().__init__(event_engine, gateway_name)
self.api = IbApi(self)
@ -483,13 +495,14 @@ class IbApi(EWrapper):
self.gateway.write_log(msg)
return
ib_size = contract.multiplier
if not ib_size:
try:
ib_size = int(contract.multiplier)
except ValueError:
ib_size = 1
price = averageCost / ib_size
pos = PositionData(
symbol=contract.conId,
symbol=generate_symbol(contract),
exchange=exchange,
direction=Direction.NET,
volume=position,
@ -807,7 +820,7 @@ class IbClient(EClient):
def generate_ib_contract(symbol: str, exchange: Exchange) -> Optional[Contract]:
""""""
"""生成ib合约"""
try:
fields = symbol.split(JOIN_SYMBOL)
@ -831,7 +844,7 @@ def generate_ib_contract(symbol: str, exchange: Exchange) -> Optional[Contract]:
def generate_symbol(ib_contract: Contract) -> str:
""""""
"""生成合约代码"""
fields = [ib_contract.symbol]
if ib_contract.secType in ["FUT", "OPT", "FOP"]:

File diff suppressed because it is too large Load Diff

View File

@ -67,7 +67,8 @@ from vnpy.trader.utility import (
get_trading_date,
get_underlying_symbol,
round_to,
BarGenerator
BarGenerator,
print_dict
)
from vnpy.trader.event import EVENT_TIMER
@ -214,6 +215,12 @@ class RohonGateway(BaseGateway):
self.combiner_conf_dict = c.get_config()
if len(self.combiner_conf_dict) > 0:
self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict))
contract_dict = c.get_contracts()
for vt_symbol, contract in contract_dict.items():
contract.gateway_name = self.gateway_name
self.on_contract(contract)
except Exception as ex: # noqa
pass

View File

@ -200,7 +200,7 @@ class XtpGateway(BaseGateway):
def process_timer_event(self, event) -> None:
""""""
self.count += 1
if self.count < 2:
if self.count < 5:
return
self.count = 0
@ -371,6 +371,8 @@ class XtpMdApi(MdApi):
min_volume=data["buy_qty_unit"],
gateway_name=self.gateway_name
)
#if contract.symbol.startswith('1230'):
# self.gateway.write_log(msg=f'合约信息:{contract.__dict__}')
self.gateway.on_contract(contract)
# 更新最新价

View File

@ -43,6 +43,7 @@ class Status(Enum):
CANCELLED = "已撤销"
CANCELLING = "撤销中"
REJECTED = "拒单"
UNKNOWN = "未知"
class Product(Enum):
@ -95,9 +96,12 @@ class Exchange(Enum):
SZSE = "SZSE" # Shenzhen Stock Exchange
SGE = "SGE" # Shanghai Gold Exchange
WXE = "WXE" # Wuxi Steel Exchange
CFETS = "CFETS" # China Foreign Exchange Trade System
# Global
SMART = "SMART" # Smart Router for US stocks
NYSE = "NYSE" # New York Stock Exchnage
NASDAQ = "NASDAQ" # Nasdaq Exchange
NYMEX = "NYMEX" # New York Mercantile Exchange
COMEX = "COMEX" # a division of theNew York Mercantile Exchange
GLOBEX = "GLOBEX" # Globex of CME

View File

@ -57,10 +57,11 @@ class OffsetConverter:
def get_position_holding(self, vt_symbol: str, gateway_name: str = '') -> "PositionHolding":
"""获取持仓信息"""
if len(gateway_name) == 0:
if gateway_name is None or len(gateway_name) == 0:
contract = self.main_engine.get_contract(vt_symbol)
if contract:
gateway_name = contract.gateway_name
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:

View File

@ -331,13 +331,15 @@ class LocalOrderManager:
Management tool to support use local order id for trading.
"""
def __init__(self, gateway: BaseGateway, order_prefix: str = ""):
def __init__(self, gateway: BaseGateway, order_prefix: str = "", order_rjust:int = 8):
""""""
self.gateway: BaseGateway = gateway
# For generating local orderid
self.order_prefix: str = order_prefix
self.order_rjust: int = order_rjust
self.order_count: int = 0
self.orders: Dict[str, OrderData] = {} # local_orderid: order
# Map between local and system orderid
@ -362,7 +364,7 @@ class LocalOrderManager:
Generate a new local orderid.
"""
self.order_count += 1
local_orderid = self.order_prefix + str(self.order_count).rjust(8, "0")
local_orderid = self.order_prefix + str(self.order_count).rjust(self.order_rjust, "0")
return local_orderid
def get_local_orderid(self, sys_orderid: str) -> str:
@ -421,8 +423,11 @@ class LocalOrderManager:
def get_order_with_local_orderid(self, local_orderid: str) -> OrderData:
""""""
order = self.orders[local_orderid]
return copy(order)
order = self.orders.get(local_orderid, None)
if order:
return copy(order)
else:
return None
def on_order(self, order: OrderData) -> None:
"""

View File

@ -1041,7 +1041,7 @@ class ContractManager(QtWidgets.QWidget):
all_contracts = self.main_engine.get_all_contracts()
if flt:
contracts = [
contract for contract in all_contracts if flt in contract.vt_symbol
contract for contract in all_contracts if flt in contract.vt_symbol.lower()
]
else:
contracts = all_contracts

View File

@ -344,7 +344,7 @@ def get_csv_last_dt(file_name, dt_index=0, dt_format='%Y-%m-%d %H:%M:%S', line_l
return None
return None
def append_data(file_name: str, dict_data: dict, field_names: list = []):
def append_data(file_name: str, dict_data: dict, field_names: list = [], auto_header=True, encoding='utf8'):
"""
添加数据到csv文件中
:param file_name: csv的文件全路径
@ -354,15 +354,16 @@ def append_data(file_name: str, dict_data: dict, field_names: list = []):
dict_fieldnames = sorted(list(dict_data.keys())) if len(field_names) == 0 else field_names
try:
if not os.path.exists(file_name):
if not os.path.exists(file_name): # or os.path.getsize(file_name) == 0:
print(u'create csv file:{}'.format(file_name))
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel')
print(u'write csv header:{}'.format(dict_fieldnames))
writer.writeheader()
if auto_header:
print(u'write csv header:{}'.format(dict_fieldnames))
writer.writeheader()
writer.writerow(dict_data)
else:
with open(file_name, 'a', encoding='utf8', newline='\n') as csvWriteFile:
with open(file_name, 'a', encoding=encoding, newline='\n') as csvWriteFile:
writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel',
extrasaction='ignore')
writer.writerow(dict_data)