[bug fix]

This commit is contained in:
msincenselee 2020-03-19 23:34:43 +08:00
parent 8bec0006ac
commit d228ca9559
11 changed files with 91 additions and 36 deletions

View File

@ -37,6 +37,16 @@
conda config --set show_channel_urls yes conda config --set show_channel_urls yes
conda install -c quantopian ta-lib=0.4.9 conda install -c quantopian ta-lib=0.4.9
若出现libta_lib.so.0 cannot open shared object file no such file or directory
解决:
sudo find / -name libta_lib.so.0
/home/ai/eco-ta/ta-lib/src/.libs/libta_lib.so.0
/usr/local/lib/libta_lib.so.0
vi /etc/profile
添加
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
source /etc/profile
9、数字货币的增量安装 9、数字货币的增量安装
conda install scipy conda install scipy

View File

@ -316,7 +316,7 @@ class AlgoEngine(BaseEngine):
"""查询合约的size""" """查询合约的size"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract is None: if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息') self.write_error(f'get_size 查询不到{vt_symbol}合约信息')
return 10 return 10
return contract.size return contract.size
@ -325,7 +325,7 @@ class AlgoEngine(BaseEngine):
"""查询保证金比率""" """查询保证金比率"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract is None: if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息') self.write_error(f'get_margin_rate 查询不到{vt_symbol}合约信息')
return 0.1 return 0.1
if contract.margin_rate == 0: if contract.margin_rate == 0:
return 0.1 return 0.1
@ -336,7 +336,7 @@ class AlgoEngine(BaseEngine):
"""查询价格最小跳动""" """查询价格最小跳动"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract is None: if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息') self.write_error(f'get_price_tick 查询不到{vt_symbol}合约信息')
return 0.1 return 0.1
return contract.pricetick return contract.pricetick

View File

@ -1762,7 +1762,7 @@ class BackTestingEngine(object):
holding_profit = 0 holding_profit = 0
last_price = self.get_price(symbol) last_price = self.get_price(symbol)
if last_price is not None: if last_price is not None:
holding_profit = (last_price - longpos.price) * longpos.volume * self.get_size(symbol) holding_profit = (last_price - longpos.price) * longpos.volume
long_pos_occupy_money += last_price * abs(longpos.volume) * self.get_margin_rate(symbol) long_pos_occupy_money += last_price * abs(longpos.volume) * self.get_margin_rate(symbol)
# 账号的持仓盈亏 # 账号的持仓盈亏
@ -1780,7 +1780,7 @@ class BackTestingEngine(object):
holding_profit = 0 holding_profit = 0
last_price = self.get_price(symbol) last_price = self.get_price(symbol)
if last_price is not None: if last_price is not None:
holding_profit = (shortpos.price - last_price) * shortpos.volume * self.get_size(symbol) holding_profit = (shortpos.price - last_price) * shortpos.volume
short_pos_occupy_money += last_price * abs(shortpos.volume) * self.get_margin_rate(symbol) short_pos_occupy_money += last_price * abs(shortpos.volume) * self.get_margin_rate(symbol)
# 账号的持仓盈亏 # 账号的持仓盈亏

View File

@ -1021,17 +1021,19 @@ class CtaEngine(BaseEngine):
strategy = self.strategies[strategy_name] strategy = self.strategies[strategy_name]
if not strategy.inited: if not strategy.inited:
self.write_error(f"策略{strategy.strategy_name}启动失败,请先初始化") self.write_error(f"策略{strategy.strategy_name}启动失败,请先初始化")
return return False
if strategy.trading: if strategy.trading:
self.write_error(f"{strategy_name}已经启动,请勿重复操作") self.write_error(f"{strategy_name}已经启动,请勿重复操作")
return return False
self.call_strategy_func(strategy, strategy.on_start) self.call_strategy_func(strategy, strategy.on_start)
strategy.trading = True strategy.trading = True
self.put_strategy_event(strategy) self.put_strategy_event(strategy)
return True
def stop_strategy(self, strategy_name: str): def stop_strategy(self, strategy_name: str):
""" """
Stop a strategy. Stop a strategy.
@ -1039,7 +1041,7 @@ class CtaEngine(BaseEngine):
strategy = self.strategies[strategy_name] strategy = self.strategies[strategy_name]
if not strategy.trading: if not strategy.trading:
self.write_log(f'{strategy_name}策略实例已处于停止交易状态') self.write_log(f'{strategy_name}策略实例已处于停止交易状态')
return return False
# Call on_stop function of the strategy # Call on_stop function of the strategy
self.write_log(f'调用{strategy_name}的on_stop,停止交易') self.write_log(f'调用{strategy_name}的on_stop,停止交易')
@ -1059,6 +1061,8 @@ class CtaEngine(BaseEngine):
# Update GUI # Update GUI
self.put_strategy_event(strategy) self.put_strategy_event(strategy)
return True
def edit_strategy(self, strategy_name: str, setting: dict): def edit_strategy(self, strategy_name: str, setting: dict):
""" """
Edit parameters of a strategy. Edit parameters of a strategy.
@ -1080,7 +1084,7 @@ class CtaEngine(BaseEngine):
strategy = self.strategies[strategy_name] strategy = self.strategies[strategy_name]
if strategy.trading: if strategy.trading:
self.write_error(f"策略{strategy.strategy_name}移除失败,请先停止") self.write_error(f"策略{strategy.strategy_name}移除失败,请先停止")
return return False
# Remove setting # Remove setting
self.remove_strategy_setting(strategy_name) self.remove_strategy_setting(strategy_name)
@ -1325,7 +1329,7 @@ class CtaEngine(BaseEngine):
:param strategy_name: :param strategy_name:
:return: :return:
""" """
return [{k: {'inited': v.inited, 'trading': v.trading}} for k, v in self.strategies.items()] return {k: {'inited': v.inited, 'trading': v.trading} for k, v in self.strategies.items()}
def get_strategy_pos(self, name, strategy=None): def get_strategy_pos(self, name, strategy=None):
""" """

View File

@ -366,6 +366,8 @@ class BackTestingEngine(object):
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''): def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)""" """ 查询合约在账号的持仓(包含多空)"""
if gateway_name:
gateway_name = self.gateway_name
k = f'{gateway_name}.{vt_symbol}' k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None) holding = self.holdings.get(k, None)
if not holding: if not holding:
@ -1070,7 +1072,7 @@ class BackTestingEngine(object):
strategy.on_stop_order(stop_order) strategy.on_stop_order(stop_order)
strategy.on_order(order) strategy.on_order(order)
self.append_trade(trade) self.append_trade(trade)
holding = self.get_position_holding(vt_symbol=trade.vt_symbol) holding = self.get_position_holding(vt_symbol=trade.vt_symbol, gateway_name=self.gateway_name)
holding.update_trade(trade) holding.update_trade(trade)
strategy.on_trade(trade) strategy.on_trade(trade)
@ -1154,14 +1156,11 @@ class BackTestingEngine(object):
self.write_log(u'vt_trade_id:{0}'.format(cov_trade.vt_tradeid)) self.write_log(u'vt_trade_id:{0}'.format(cov_trade.vt_tradeid))
# 更新持仓缓存数据 # 更新持仓缓存数据
pos_buffer = self.pos_holding_dict.get(cov_trade.vt_symbol, None) holding = self.get_position_holding(cov_trade.vt_symbol, self.gateway_name)
if not pos_buffer: holding.update_trade(cov_trade)
pos_buffer = PositionHolding(self.get_contract(vt_symbol))
self.pos_holding_dict[cov_trade.vt_symbol] = pos_buffer
pos_buffer.update_trade(cov_trade)
self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(cov_trade.strategy_name, self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(cov_trade.strategy_name,
cov_trade.tradeid, cov_trade.tradeid,
pos_buffer.to_str())) holding.to_str()))
# 写入交易记录 # 写入交易记录
self.append_trade(cov_trade) self.append_trade(cov_trade)

View File

@ -141,6 +141,7 @@ class CtaEngine(BaseEngine):
self.load_strategy_class() self.load_strategy_class()
self.load_strategy_setting() self.load_strategy_setting()
self.register_event() self.register_event()
self.register_funcs()
self.write_log("CTA策略引擎初始化成功") self.write_log("CTA策略引擎初始化成功")
def close(self): def close(self):
@ -173,18 +174,17 @@ class CtaEngine(BaseEngine):
self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot
# 注册到远程服务调用 # 注册到远程服务调用
rpc_service = self.main_engine.apps.get('RpcService') if self.main_engine.rpc_service:
if rpc_service: self.main_engine.rpc_service.register(self.main_engine.get_strategy_status)
rpc_service.register(self.main_engine.get_strategy_status) self.main_engine.rpc_service.register(self.main_engine.get_strategy_pos)
rpc_service.register(self.main_engine.get_strategy_pos) self.main_engine.rpc_service.register(self.main_engine.add_strategy)
rpc_service.register(self.main_engine.add_strategy) self.main_engine.rpc_service.register(self.main_engine.init_strategy)
rpc_service.register(self.main_engine.init_strategy) self.main_engine.rpc_service.register(self.main_engine.start_strategy)
rpc_service.register(self.main_engine.start_strategy) self.main_engine.rpc_service.register(self.main_engine.stop_strategy)
rpc_service.register(self.main_engine.stop_strategy) self.main_engine.rpc_service.register(self.main_engine.remove_strategy)
rpc_service.register(self.main_engine.remove_strategy) self.main_engine.rpc_service.register(self.main_engine.reload_strategy)
rpc_service.register(self.main_engine.reload_strategy) self.main_engine.rpc_service.register(self.main_engine.save_strategy_data)
rpc_service.register(self.main_engine.save_strategy_data) self.main_engine.rpc_service.register(self.main_engine.save_strategy_snapshot)
rpc_service.register(self.main_engine.save_strategy_snapshot)
def process_timer_event(self, event: Event): def process_timer_event(self, event: Event):
""" 处理定时器事件""" """ 处理定时器事件"""

View File

@ -279,7 +279,7 @@ class BinancefRestApi(RestClient):
self.start(session_number) self.start(session_number)
self.gateway.write_log("REST API启动成功") self.gateway.write_log("REST API启动成功")
self.gateway.status.update({'md_con': True, 'md_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) self.gateway.status.update({'td_con': True, 'td_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
self.query_time() self.query_time()
self.query_account() self.query_account()
@ -922,7 +922,7 @@ class BinancefDataWebsocketApi(WebsocketClient):
def on_connected(self) -> None: def on_connected(self) -> None:
"""""" """"""
self.gateway.write_log("行情Websocket API连接刷新") self.gateway.write_log("行情Websocket API连接刷新")
self.gateway.status.update({'mdws_con': True, 'mdws_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) self.gateway.status.update({'md_con': True, 'md_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def subscribe(self, req: SubscribeRequest) -> None: def subscribe(self, req: SubscribeRequest) -> None:
"""""" """"""

View File

@ -801,7 +801,7 @@ class CtpTdApi(TdApi):
account.commission = round(float(data['Commission']), 7) account.commission = round(float(data['Commission']), 7)
account.margin = round(float(data['CurrMargin']), 7) account.margin = round(float(data['CurrMargin']), 7)
account.close_profit = round(float(data['CloseProfit']), 7) account.close_profit = round(float(data['CloseProfit']), 7)
account.holding_profit = round(float(data['PositionProfit']),7) account.holding_profit = round(float(data['PositionProfit']), 7)
account.trading_day = str(data['TradingDay']) account.trading_day = str(data['TradingDay'])
if '-' not in account.trading_day and len(account.trading_day) == 8: if '-' not in account.trading_day and len(account.trading_day) == 8:
account.trading_day = '-'.join( account.trading_day = '-'.join(

View File

@ -310,7 +310,7 @@ class RpcClient:
while self.__active: while self.__active:
if not self.__socket_sub.poll(pull_tolerance): if not self.__socket_sub.poll(pull_tolerance):
self._on_unexpected_disconnected() #self._on_unexpected_disconnected()
continue continue
# Receive data from subscribe socket # Receive data from subscribe socket

View File

@ -171,6 +171,9 @@ class PositionHolding:
if self.short_td < 0: if self.short_td < 0:
self.short_yd += self.short_td self.short_yd += self.short_td
self.short_td = 0 self.short_td = 0
self.short_yd = round(self.short_yd, 7)
self.short_td = round(self.short_td, 7)
else: else:
if trade.offset == Offset.OPEN: if trade.offset == Offset.OPEN:
self.short_td += trade.volume self.short_td += trade.volume
@ -187,9 +190,11 @@ class PositionHolding:
if self.long_td < 0: if self.long_td < 0:
self.long_yd += self.long_td self.long_yd += self.long_td
self.long_td = 0 self.long_td = 0
self.long_td = round(self.long_td, 7)
self.long_yd = round(self.long_yd, 7)
self.long_pos = self.long_td + self.long_yd self.long_pos = round(self.long_td + self.long_yd, 7)
self.short_pos = self.short_td + self.short_yd self.short_pos = round(self.short_td + self.short_yd, 7)
def calculate_frozen(self) -> None: def calculate_frozen(self) -> None:
"""""" """"""

View File

@ -66,6 +66,7 @@ class MainEngine:
self.rm_engine = None self.rm_engine = None
self.algo_engine = None self.algo_engine = None
self.rpc_service = None
os.chdir(TRADER_DIR) # Change working directory os.chdir(TRADER_DIR) # Change working directory
self.init_engines() # Initialize function engines self.init_engines() # Initialize function engines
@ -111,6 +112,8 @@ class MainEngine:
self.rm_engine = engine self.rm_engine = engine
elif app.app_name == "AlgoTrading": elif app.app_name == "AlgoTrading":
self.algo_engine == engine self.algo_engine == engine
elif app.app_name == 'RpcService':
self.rpc_service = engine
return engine return engine
@ -262,6 +265,9 @@ class MainEngine:
Make sure every gateway and app is closed properly before Make sure every gateway and app is closed properly before
programme exit. programme exit.
""" """
if hasattr(self, 'save_contracts'):
self.save_contracts()
# Stop event engine first to prevent new timer event. # Stop event engine first to prevent new timer event.
self.event_engine.stop() self.event_engine.stop()
@ -315,7 +321,7 @@ class BaseEngine(ABC):
msg = f'[{source}]{msg}' msg = f'[{source}]{msg}'
self.logger.log(level, msg) self.logger.log(level, msg)
else: else:
log = LogData(msg=msg, level=level) log = LogData(msg=msg, level=level, gateway_name='')
event = Event(EVENT_LOG, log) event = Event(EVENT_LOG, log)
self.event_engine.put(event) self.event_engine.put(event)
@ -414,6 +420,7 @@ class OmsEngine(BaseEngine):
self.positions: Dict[str, PositionData] = {} self.positions: Dict[str, PositionData] = {}
self.accounts: Dict[str, AccountData] = {} self.accounts: Dict[str, AccountData] = {}
self.contracts: Dict[str, ContractData] = {} self.contracts: Dict[str, ContractData] = {}
self.today_contracts: Dict[str, ContractData] = {}
self.custom_contracts = {} self.custom_contracts = {}
self.prices = {} self.prices = {}
@ -422,6 +429,33 @@ class OmsEngine(BaseEngine):
self.add_function() self.add_function()
self.register_event() self.register_event()
def __del__(self):
"""保存缓存"""
self.save_contracts()
def load_contracts(self) -> None:
"""从本地缓存加载合约字典"""
import bz2
import pickle
contract_file_name = 'vn_contract.pkb2'
if not os.path.exists(contract_file_name):
return
with bz2.BZ2File(contract_file_name, 'rb') as f:
self.contracts = pickle.load(f)
self.write_log(f'加载缓存合约字典:{contract_file_name}')
def save_contracts(self) -> None:
"""持久化合约对象到缓存文件"""
import bz2
import pickle
contract_file_name = 'vn_contract.pkb2'
with bz2.BZ2File(contract_file_name, 'wb') as f:
if len(self.today_contracts) > 0:
self.write_log(f'保存今日合约对象到缓存文件')
pickle.dump(self.today_contracts, f)
else:
pickle.dump(self.contracts, f)
def add_function(self) -> None: def add_function(self) -> None:
"""Add query function to main engine.""" """Add query function to main engine."""
self.main_engine.get_tick = self.get_tick self.main_engine.get_tick = self.get_tick
@ -439,6 +473,7 @@ class OmsEngine(BaseEngine):
self.main_engine.get_all_contracts = self.get_all_contracts self.main_engine.get_all_contracts = self.get_all_contracts
self.main_engine.get_all_active_orders = self.get_all_active_orders self.main_engine.get_all_active_orders = self.get_all_active_orders
self.main_engine.get_all_custom_contracts = self.get_all_custom_contracts self.main_engine.get_all_custom_contracts = self.get_all_custom_contracts
self.main_engine.save_contracts = self.save_contracts
def register_event(self) -> None: def register_event(self) -> None:
"""""" """"""
@ -489,6 +524,8 @@ class OmsEngine(BaseEngine):
contract = event.data contract = event.data
self.contracts[contract.vt_symbol] = contract self.contracts[contract.vt_symbol] = contract
self.contracts[contract.symbol] = contract self.contracts[contract.symbol] = contract
self.today_contracts[contract.vt_symbol] = contract
self.today_contracts[contract.symbol] = contract
def get_tick(self, vt_symbol: str) -> Optional[TickData]: def get_tick(self, vt_symbol: str) -> Optional[TickData]:
""" """