From 7ce42e54a412b944a5b1f863d0a95b8e0a4af07d Mon Sep 17 00:00:00 2001 From: msincenselee Date: Sun, 9 Feb 2020 19:13:09 +0800 Subject: [PATCH] [update] --- vnpy/app/cta_strategy_pro/cta_line_bar.py | 2 + vnpy/app/cta_strategy_pro/engine.py | 133 ++++++++++++++++-- .../app/cta_strategy_pro/portfolio_testing.py | 7 +- vnpy/app/cta_strategy_pro/template.py | 17 ++- vnpy/app/cta_strategy_pro/ui/widget.py | 15 +- vnpy/trader/utility.py | 56 ++++++++ 6 files changed, 205 insertions(+), 25 deletions(-) diff --git a/vnpy/app/cta_strategy_pro/cta_line_bar.py b/vnpy/app/cta_strategy_pro/cta_line_bar.py index c0489f9f..c6c02259 100644 --- a/vnpy/app/cta_strategy_pro/cta_line_bar.py +++ b/vnpy/app/cta_strategy_pro/cta_line_bar.py @@ -210,6 +210,8 @@ class CtaLineBar(object): def init_param_list(self): self.paramList.append('bar_interval') self.paramList.append('interval') + self.paramList.append('mode') + self.paramList.append('para_pre_len') self.paramList.append('para_ma1_len') self.paramList.append('para_ma2_len') diff --git a/vnpy/app/cta_strategy_pro/engine.py b/vnpy/app/cta_strategy_pro/engine.py index 0281c655..fb2387c3 100644 --- a/vnpy/app/cta_strategy_pro/engine.py +++ b/vnpy/app/cta_strategy_pro/engine.py @@ -133,7 +133,8 @@ class CtaEngine(BaseEngine): self.stop_order_count = 0 # for generating stop_orderid self.stop_orders = {} # stop_orderid: stop_order - self.init_executor = ThreadPoolExecutor(max_workers=1) + self.thread_executor = ThreadPoolExecutor(max_workers=1) + self.thread_tasks = [] self.vt_tradeids = set() # for filtering duplicate trade @@ -722,7 +723,7 @@ class CtaEngine(BaseEngine): # 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射 subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name] - subscribe_symbol_set.add(contract.vt_symbol) + subscribe_symbol_set.add(vt_symbol) return True @@ -777,7 +778,8 @@ class CtaEngine(BaseEngine): accounts = self.main_engine.get_all_accounts() if len(accounts) > 0: account = accounts[0] - return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100 + return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), + 2), 100 else: return 0, 0, 0, 0 @@ -858,7 +860,7 @@ class CtaEngine(BaseEngine): subscribe_symbol_set.add(vt_symbol) # Update to setting file. - self.update_strategy_setting(strategy_name, setting) + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) self.put_strategy_event(strategy) @@ -870,7 +872,8 @@ class CtaEngine(BaseEngine): """ Init a strategy. """ - self.init_executor.submit(self._init_strategy, strategy_name, auto_start) + task = self.thread_executor.submit(self._init_strategy, strategy_name, auto_start) + self.thread_tasks.append(task) def _init_strategy(self, strategy_name: str, auto_start: bool = False): """ @@ -959,9 +962,12 @@ class CtaEngine(BaseEngine): 风险警示: 该方法强行干预策略的配置 """ strategy = self.strategies[strategy_name] + auto_init = setting.pop('auto_init', False) + auto_start = setting.pop('auto_start', False) + strategy.update_setting(setting) - self.update_strategy_setting(strategy_name, setting) + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) self.put_strategy_event(strategy) def remove_strategy(self, strategy_name: str): @@ -1012,7 +1018,13 @@ class CtaEngine(BaseEngine): if strategy_name not in self.strategies or strategy_name not in self.strategy_setting: self.write_error(f"{strategy_name}不在运行策略中,不能重启") return False - old_strategy_config = copy(self.strategy_setting[strategy_name]) + + # 从本地配置文件中读取 + if len(setting) == 0: + strategies_setting = load_json(self.setting_filename) + old_strategy_config = strategies_setting.get(strategy_name, {}) + else: + old_strategy_config = copy(self.strategy_setting[strategy_name]) class_name = old_strategy_config.get('class_name') if len(vt_symbol) == 0: @@ -1035,11 +1047,98 @@ class CtaEngine(BaseEngine): self.add_strategy(class_name=class_name, strategy_name=strategy_name, vt_symbol=vt_symbol, - setting=setting) + setting=setting, + auto_init=old_strategy_config.get('auto_init', False), + auto_start=old_strategy_config.get('auto_start', False)) self.write_log(f'重新运行策略{strategy_name}执行完毕') return True + def save_strategy_data(self, select_name: str): + """ save strategy data""" + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_data, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存数据\n' + has_executed = True + else: + self.write_log(f'{strategy_name}未初始化/未启动交易,不进行保存数据') + return has_executed, msg + + def thread_save_strategy_data(self, strategy_name): + """异步线程保存策略数据""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + try: + # 保存策略数据 + strategy.sync_data() + except Exception as ex: + self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + + def save_strategy_snapshot(self, select_name: str): + """ + 保存策略K线切片数据 + :param select_name: + :return: + """ + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + if not hasattr(strategy, 'get_klines_snapshot'): + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_snapshot, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存K线切片\n' + has_executed = True + + return has_executed, msg + + def thread_save_strategy_snapshot(self, strategy_name): + """异步线程保存策略切片""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + + try: + # 5.保存策略切片 + snapshot = strategy.get_klines_snapshot() + if len(snapshot) == 0: + self.write_log(f'{strategy_name}返回得K线切片数据为空') + return + + # 剩下工作:保存本地文件/数据库 + + except Exception as ex: + self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + def load_strategy_class(self): """ Load strategy class from source code. @@ -1209,7 +1308,7 @@ class CtaEngine(BaseEngine): # SPD合约 spd_vt_symbol = pos.get('vt_symbol', None) if spd_vt_symbol is not None and spd_vt_symbol.endswith('SPD'): - spd_symbol,spd_exchange = extract_vt_symbol(spd_vt_symbol) + spd_symbol, spd_exchange = extract_vt_symbol(spd_vt_symbol) spd_setting = self.main_engine.get_all_custom_contracts().get(spd_symbol, None) if spd_setting is None: @@ -1225,7 +1324,7 @@ class CtaEngine(BaseEngine): leg1_pos.update({'symbol': spd_setting.get('leg1_symbol')}) leg1_pos.update({'vt_symbol': spd_setting.get('leg1_symbol')}) leg1_pos.update({'direction': leg1_direction}) - leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1)*spd_volume}) + leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1) * spd_volume}) leg2_pos = {} leg2_pos.update({'symbol': spd_setting.get('leg2_symbol')}) @@ -1288,7 +1387,12 @@ class CtaEngine(BaseEngine): Get parameters of a strategy. """ strategy = self.strategies[strategy_name] - return strategy.get_parameters() + strategy_config = self.strategy_setting.get(strategy_name, {}) + d = {} + d.update({'auto_init': strategy_config.get('auto_init', False)}) + d.update({'auto_start': strategy_config.get('auto_start', False)}) + d.update(strategy.get_parameters()) + return d def init_all_strategies(self): """ @@ -1328,7 +1432,8 @@ class CtaEngine(BaseEngine): auto_start=strategy_config.get('auto_start', False) ) - def update_strategy_setting(self, strategy_name: str, setting: dict): + def update_strategy_setting(self, strategy_name: str, setting: dict, auto_init: bool = False, + auto_start: bool = False): """ Update setting file. """ @@ -1339,8 +1444,8 @@ class CtaEngine(BaseEngine): self.strategy_setting[strategy_name] = { "class_name": strategy.__class__.__name__, "vt_symbol": strategy.vt_symbol, - "auto_init": strategy_config.get('auto_init', False), - "auto_start": strategy_config.get('auto_start', False), + "auto_init": auto_init, + "auto_start": auto_start, "setting": setting } save_json(self.setting_filename, self.strategy_setting) diff --git a/vnpy/app/cta_strategy_pro/portfolio_testing.py b/vnpy/app/cta_strategy_pro/portfolio_testing.py index 65240297..e4979474 100644 --- a/vnpy/app/cta_strategy_pro/portfolio_testing.py +++ b/vnpy/app/cta_strategy_pro/portfolio_testing.py @@ -897,7 +897,7 @@ class PortfolioTestingEngine(object): """保存策略数据""" for strategy in self.strategies.values(): self.write_log(u'save strategy data') - strategy.saveData() + strategy.save_data() def send_order(self, strategy: CtaTemplate, @@ -2274,7 +2274,10 @@ class PortfolioTestingEngine(object): d = OrderedDict() try: for k in trade_fields: - d[k] = getattr(trade, k, '') + if k in ['exchange', 'direction', 'offset']: + d[k] = getattr(trade, k).value + else: + d[k] = getattr(trade, k, '') trade_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade.csv'.format(strategy_name))) self.append_data(file_name=trade_file, dict_data=d) diff --git a/vnpy/app/cta_strategy_pro/template.py b/vnpy/app/cta_strategy_pro/template.py index 3199d370..c3296eda 100644 --- a/vnpy/app/cta_strategy_pro/template.py +++ b/vnpy/app/cta_strategy_pro/template.py @@ -526,6 +526,8 @@ class CtaProTemplate(CtaTemplate): 增强模板 """ + backtesting = False + # 逻辑过程日志 dist_fieldnames = ['datetime', 'symbol', 'volume', 'price', 'operation', 'signal', 'stop_price', 'target_price', @@ -552,8 +554,6 @@ class CtaProTemplate(CtaTemplate): self.cancel_seconds = 120 # 撤单时间(秒) - self.backtesting = False - self.klines = {} # K线字典: kline_name: kline # 增加仓位管理模块 @@ -601,7 +601,11 @@ class CtaProTemplate(CtaTemplate): with bz2.BZ2File(file_name, 'wb') as f: klines = {} for kline_name in kline_names: - klines.update({kline_name: self.klines.get(kline_name, None)}) + kline = self.klines.get(kline_name, None) + #if kline: + # kline.strategy = None + # kline.cb_on_bar = None + klines.update({kline_name: kline}) pickle.dump(klines, f) def load_klines_from_cache(self, kline_names: list = []): @@ -657,11 +661,11 @@ class CtaProTemplate(CtaTemplate): 'strategy': self.strategy_name, 'datetime': datetime.now()} - for kline_name in self.klines.keys(): + for kline_name in sorted(self.klines.keys()): d.update({kline_name: self.klines.get(kline_name).get_data()}) return d except Exception as ex: - self.write_error(u'获取klines切片数据失败') + self.write_error(f'获取klines切片数据失败:{str(ex)}') return {} def init_position(self): @@ -959,8 +963,9 @@ class CtaProTemplate(CtaTemplate): else: save_path = self.cta_engine.get_data_path() try: - if self.position: + if self.position and 'long_pos' not in dist_data: dist_data.update({'long_pos': self.position.long_pos}) + if self.position and 'short_pos' not in dist_data: dist_data.update({'short_pos': self.position.short_pos}) file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv')) diff --git a/vnpy/app/cta_strategy_pro/ui/widget.py b/vnpy/app/cta_strategy_pro/ui/widget.py index ab99b68d..935f912d 100644 --- a/vnpy/app/cta_strategy_pro/ui/widget.py +++ b/vnpy/app/cta_strategy_pro/ui/widget.py @@ -144,9 +144,10 @@ class CtaManager(QtWidgets.QWidget): setting = editor.get_setting() vt_symbol = setting.pop("vt_symbol") strategy_name = setting.pop("strategy_name") - + auto_init = setting.pop("auto_init", False) + auto_start = setting.pop("auto_start", False) self.cta_engine.add_strategy( - class_name, strategy_name, vt_symbol, setting + class_name, strategy_name, vt_symbol, setting, auto_init, auto_start ) def clear_log(self): @@ -201,6 +202,9 @@ class StrategyManager(QtWidgets.QFrame): reload_button = QtWidgets.QPushButton("重载") reload_button.clicked.connect(self.reload_strategy) + save_button = QtWidgets.QPushButton("保存") + save_button.clicked.connect(self.save_strategy) + strategy_name = self._data["strategy_name"] vt_symbol = self._data["vt_symbol"] class_name = self._data["class_name"] @@ -222,6 +226,7 @@ class StrategyManager(QtWidgets.QFrame): hbox.addWidget(edit_button) hbox.addWidget(remove_button) hbox.addWidget(reload_button) + hbox.addWidget(save_button) vbox = QtWidgets.QVBoxLayout() vbox.addWidget(label) @@ -273,6 +278,9 @@ class StrategyManager(QtWidgets.QFrame): """重新加载策略""" self.cta_engine.reload_strategy(self.strategy_name) + def save_strategy(self): + self.cta_engine.save_strategy_data(self.strategy_name) + class DataMonitor(QtWidgets.QTableWidget): """ @@ -403,8 +411,9 @@ class SettingEditor(QtWidgets.QDialog): if self.class_name: self.setWindowTitle(f"添加策略:{self.class_name}") button_text = "添加" - parameters = {"strategy_name": "", "vt_symbol": ""} + parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True} parameters.update(self.parameters) + else: self.setWindowTitle(f"参数编辑:{self.strategy_name}") button_text = "确定" diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 091fb49a..40f52b41 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -1219,3 +1219,59 @@ def get_file_logger(filename: str): handler.setFormatter(log_formatter) logger.addHandler(handler) # each handler will be added only once. return logger + + +def get_bars(csv_file: str, + symbol: str, + exchange: Exchange, + start_date: datetime = None, + end_date: datetime = None,): + """ + 获取bar + 数据存储目录: 项目/bar_data + :param csv_file: csv文件路径 + :param symbol: 合约 + :param exchange 交易所 + :param start_date: datetime + :param end_date: datetime + :return: + """ + bars = [] + + import csv + with open(file=csv_file, mode='r', encoding='utf8', newline='\n') as f: + reader = csv.DictReader(f) + + count = 0 + + for item in reader: + + dt = datetime.strptime(item['datetime'], '%Y-%m-%d %H:%M:%S') + if start_date: + if dt < start_date: + continue + if end_date: + if dt > end_date: + break + + bar = BarData( + symbol=symbol, + exchange=exchange, + datetime=dt, + interval=Interval.MINUTE, + volume=float(item['volume']), + open_price=float(item['open']), + high_price=float(item['high']), + low_price=float(item['low']), + close_price=float(item['close']), + open_interest=float(item['open_interest']), + trading_day=item['trading_day'], + gateway_name="Tdx", + ) + + bars.append(bar) + + # do some statistics + count += 1 + + return bars