This commit is contained in:
msincenselee 2020-02-09 19:13:09 +08:00
parent b439a419e7
commit 7ce42e54a4
6 changed files with 205 additions and 25 deletions

View File

@ -210,6 +210,8 @@ class CtaLineBar(object):
def init_param_list(self):
self.paramList.append('bar_interval')
self.paramList.append('interval')
self.paramList.append('mode')
self.paramList.append('para_pre_len')
self.paramList.append('para_ma1_len')
self.paramList.append('para_ma2_len')

View File

@ -133,7 +133,8 @@ class CtaEngine(BaseEngine):
self.stop_order_count = 0 # for generating stop_orderid
self.stop_orders = {} # stop_orderid: stop_order
self.init_executor = ThreadPoolExecutor(max_workers=1)
self.thread_executor = ThreadPoolExecutor(max_workers=1)
self.thread_tasks = []
self.vt_tradeids = set() # for filtering duplicate trade
@ -722,7 +723,7 @@ class CtaEngine(BaseEngine):
# 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射
subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name]
subscribe_symbol_set.add(contract.vt_symbol)
subscribe_symbol_set.add(vt_symbol)
return True
@ -777,7 +778,8 @@ class CtaEngine(BaseEngine):
accounts = self.main_engine.get_all_accounts()
if len(accounts) > 0:
account = accounts[0]
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01),
2), 100
else:
return 0, 0, 0, 0
@ -858,7 +860,7 @@ class CtaEngine(BaseEngine):
subscribe_symbol_set.add(vt_symbol)
# Update to setting file.
self.update_strategy_setting(strategy_name, setting)
self.update_strategy_setting(strategy_name, setting, auto_init, auto_start)
self.put_strategy_event(strategy)
@ -870,7 +872,8 @@ class CtaEngine(BaseEngine):
"""
Init a strategy.
"""
self.init_executor.submit(self._init_strategy, strategy_name, auto_start)
task = self.thread_executor.submit(self._init_strategy, strategy_name, auto_start)
self.thread_tasks.append(task)
def _init_strategy(self, strategy_name: str, auto_start: bool = False):
"""
@ -959,9 +962,12 @@ class CtaEngine(BaseEngine):
风险警示 该方法强行干预策略的配置
"""
strategy = self.strategies[strategy_name]
auto_init = setting.pop('auto_init', False)
auto_start = setting.pop('auto_start', False)
strategy.update_setting(setting)
self.update_strategy_setting(strategy_name, setting)
self.update_strategy_setting(strategy_name, setting, auto_init, auto_start)
self.put_strategy_event(strategy)
def remove_strategy(self, strategy_name: str):
@ -1012,7 +1018,13 @@ class CtaEngine(BaseEngine):
if strategy_name not in self.strategies or strategy_name not in self.strategy_setting:
self.write_error(f"{strategy_name}不在运行策略中,不能重启")
return False
old_strategy_config = copy(self.strategy_setting[strategy_name])
# 从本地配置文件中读取
if len(setting) == 0:
strategies_setting = load_json(self.setting_filename)
old_strategy_config = strategies_setting.get(strategy_name, {})
else:
old_strategy_config = copy(self.strategy_setting[strategy_name])
class_name = old_strategy_config.get('class_name')
if len(vt_symbol) == 0:
@ -1035,11 +1047,98 @@ class CtaEngine(BaseEngine):
self.add_strategy(class_name=class_name,
strategy_name=strategy_name,
vt_symbol=vt_symbol,
setting=setting)
setting=setting,
auto_init=old_strategy_config.get('auto_init', False),
auto_start=old_strategy_config.get('auto_start', False))
self.write_log(f'重新运行策略{strategy_name}执行完毕')
return True
def save_strategy_data(self, select_name: str):
""" save strategy data"""
has_executed = False
msg = ""
# 1.判断策略名称是否存在字典中
for strategy_name in list(self.strategies.keys()):
if select_name != 'ALL':
if strategy_name != select_name:
continue
# 2.提取策略
strategy = self.strategies.get(strategy_name, None)
if not strategy:
continue
# 3.判断策略是否运行
if strategy.inited and strategy.trading:
task = self.thread_executor.submit(self.thread_save_strategy_data, strategy_name)
self.thread_tasks.append(task)
msg += f'{strategy_name}执行保存数据\n'
has_executed = True
else:
self.write_log(f'{strategy_name}未初始化/未启动交易,不进行保存数据')
return has_executed, msg
def thread_save_strategy_data(self, strategy_name):
"""异步线程保存策略数据"""
strategy = self.strategies.get(strategy_name, None)
if strategy is None:
return
try:
# 保存策略数据
strategy.sync_data()
except Exception as ex:
self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc())
def save_strategy_snapshot(self, select_name: str):
"""
保存策略K线切片数据
:param select_name:
:return:
"""
has_executed = False
msg = ""
# 1.判断策略名称是否存在字典中
for strategy_name in list(self.strategies.keys()):
if select_name != 'ALL':
if strategy_name != select_name:
continue
# 2.提取策略
strategy = self.strategies.get(strategy_name, None)
if not strategy:
continue
if not hasattr(strategy, 'get_klines_snapshot'):
continue
# 3.判断策略是否运行
if strategy.inited and strategy.trading:
task = self.thread_executor.submit(self.thread_save_strategy_snapshot, strategy_name)
self.thread_tasks.append(task)
msg += f'{strategy_name}执行保存K线切片\n'
has_executed = True
return has_executed, msg
def thread_save_strategy_snapshot(self, strategy_name):
"""异步线程保存策略切片"""
strategy = self.strategies.get(strategy_name, None)
if strategy is None:
return
try:
# 5.保存策略切片
snapshot = strategy.get_klines_snapshot()
if len(snapshot) == 0:
self.write_log(f'{strategy_name}返回得K线切片数据为空')
return
# 剩下工作:保存本地文件/数据库
except Exception as ex:
self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex)))
self.write_error(traceback.format_exc())
def load_strategy_class(self):
"""
Load strategy class from source code.
@ -1209,7 +1308,7 @@ class CtaEngine(BaseEngine):
# SPD合约
spd_vt_symbol = pos.get('vt_symbol', None)
if spd_vt_symbol is not None and spd_vt_symbol.endswith('SPD'):
spd_symbol,spd_exchange = extract_vt_symbol(spd_vt_symbol)
spd_symbol, spd_exchange = extract_vt_symbol(spd_vt_symbol)
spd_setting = self.main_engine.get_all_custom_contracts().get(spd_symbol, None)
if spd_setting is None:
@ -1225,7 +1324,7 @@ class CtaEngine(BaseEngine):
leg1_pos.update({'symbol': spd_setting.get('leg1_symbol')})
leg1_pos.update({'vt_symbol': spd_setting.get('leg1_symbol')})
leg1_pos.update({'direction': leg1_direction})
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1)*spd_volume})
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1) * spd_volume})
leg2_pos = {}
leg2_pos.update({'symbol': spd_setting.get('leg2_symbol')})
@ -1288,7 +1387,12 @@ class CtaEngine(BaseEngine):
Get parameters of a strategy.
"""
strategy = self.strategies[strategy_name]
return strategy.get_parameters()
strategy_config = self.strategy_setting.get(strategy_name, {})
d = {}
d.update({'auto_init': strategy_config.get('auto_init', False)})
d.update({'auto_start': strategy_config.get('auto_start', False)})
d.update(strategy.get_parameters())
return d
def init_all_strategies(self):
"""
@ -1328,7 +1432,8 @@ class CtaEngine(BaseEngine):
auto_start=strategy_config.get('auto_start', False)
)
def update_strategy_setting(self, strategy_name: str, setting: dict):
def update_strategy_setting(self, strategy_name: str, setting: dict, auto_init: bool = False,
auto_start: bool = False):
"""
Update setting file.
"""
@ -1339,8 +1444,8 @@ class CtaEngine(BaseEngine):
self.strategy_setting[strategy_name] = {
"class_name": strategy.__class__.__name__,
"vt_symbol": strategy.vt_symbol,
"auto_init": strategy_config.get('auto_init', False),
"auto_start": strategy_config.get('auto_start', False),
"auto_init": auto_init,
"auto_start": auto_start,
"setting": setting
}
save_json(self.setting_filename, self.strategy_setting)

View File

@ -897,7 +897,7 @@ class PortfolioTestingEngine(object):
"""保存策略数据"""
for strategy in self.strategies.values():
self.write_log(u'save strategy data')
strategy.saveData()
strategy.save_data()
def send_order(self,
strategy: CtaTemplate,
@ -2274,7 +2274,10 @@ class PortfolioTestingEngine(object):
d = OrderedDict()
try:
for k in trade_fields:
d[k] = getattr(trade, k, '')
if k in ['exchange', 'direction', 'offset']:
d[k] = getattr(trade, k).value
else:
d[k] = getattr(trade, k, '')
trade_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade.csv'.format(strategy_name)))
self.append_data(file_name=trade_file, dict_data=d)

View File

@ -526,6 +526,8 @@ class CtaProTemplate(CtaTemplate):
增强模板
"""
backtesting = False
# 逻辑过程日志
dist_fieldnames = ['datetime', 'symbol', 'volume', 'price',
'operation', 'signal', 'stop_price', 'target_price',
@ -552,8 +554,6 @@ class CtaProTemplate(CtaTemplate):
self.cancel_seconds = 120 # 撤单时间(秒)
self.backtesting = False
self.klines = {} # K线字典: kline_name: kline
# 增加仓位管理模块
@ -601,7 +601,11 @@ class CtaProTemplate(CtaTemplate):
with bz2.BZ2File(file_name, 'wb') as f:
klines = {}
for kline_name in kline_names:
klines.update({kline_name: self.klines.get(kline_name, None)})
kline = self.klines.get(kline_name, None)
#if kline:
# kline.strategy = None
# kline.cb_on_bar = None
klines.update({kline_name: kline})
pickle.dump(klines, f)
def load_klines_from_cache(self, kline_names: list = []):
@ -657,11 +661,11 @@ class CtaProTemplate(CtaTemplate):
'strategy': self.strategy_name,
'datetime': datetime.now()}
for kline_name in self.klines.keys():
for kline_name in sorted(self.klines.keys()):
d.update({kline_name: self.klines.get(kline_name).get_data()})
return d
except Exception as ex:
self.write_error(u'获取klines切片数据失败')
self.write_error(f'获取klines切片数据失败:{str(ex)}')
return {}
def init_position(self):
@ -959,8 +963,9 @@ class CtaProTemplate(CtaTemplate):
else:
save_path = self.cta_engine.get_data_path()
try:
if self.position:
if self.position and 'long_pos' not in dist_data:
dist_data.update({'long_pos': self.position.long_pos})
if self.position and 'short_pos' not in dist_data:
dist_data.update({'short_pos': self.position.short_pos})
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv'))

View File

@ -144,9 +144,10 @@ class CtaManager(QtWidgets.QWidget):
setting = editor.get_setting()
vt_symbol = setting.pop("vt_symbol")
strategy_name = setting.pop("strategy_name")
auto_init = setting.pop("auto_init", False)
auto_start = setting.pop("auto_start", False)
self.cta_engine.add_strategy(
class_name, strategy_name, vt_symbol, setting
class_name, strategy_name, vt_symbol, setting, auto_init, auto_start
)
def clear_log(self):
@ -201,6 +202,9 @@ class StrategyManager(QtWidgets.QFrame):
reload_button = QtWidgets.QPushButton("重载")
reload_button.clicked.connect(self.reload_strategy)
save_button = QtWidgets.QPushButton("保存")
save_button.clicked.connect(self.save_strategy)
strategy_name = self._data["strategy_name"]
vt_symbol = self._data["vt_symbol"]
class_name = self._data["class_name"]
@ -222,6 +226,7 @@ class StrategyManager(QtWidgets.QFrame):
hbox.addWidget(edit_button)
hbox.addWidget(remove_button)
hbox.addWidget(reload_button)
hbox.addWidget(save_button)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(label)
@ -273,6 +278,9 @@ class StrategyManager(QtWidgets.QFrame):
"""重新加载策略"""
self.cta_engine.reload_strategy(self.strategy_name)
def save_strategy(self):
self.cta_engine.save_strategy_data(self.strategy_name)
class DataMonitor(QtWidgets.QTableWidget):
"""
@ -403,8 +411,9 @@ class SettingEditor(QtWidgets.QDialog):
if self.class_name:
self.setWindowTitle(f"添加策略:{self.class_name}")
button_text = "添加"
parameters = {"strategy_name": "", "vt_symbol": ""}
parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True}
parameters.update(self.parameters)
else:
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
button_text = "确定"

View File

@ -1219,3 +1219,59 @@ def get_file_logger(filename: str):
handler.setFormatter(log_formatter)
logger.addHandler(handler) # each handler will be added only once.
return logger
def get_bars(csv_file: str,
symbol: str,
exchange: Exchange,
start_date: datetime = None,
end_date: datetime = None,):
"""
获取bar
数据存储目录: 项目/bar_data
:param csv_file: csv文件路径
:param symbol: 合约
:param exchange 交易所
:param start_date: datetime
:param end_date: datetime
:return:
"""
bars = []
import csv
with open(file=csv_file, mode='r', encoding='utf8', newline='\n') as f:
reader = csv.DictReader(f)
count = 0
for item in reader:
dt = datetime.strptime(item['datetime'], '%Y-%m-%d %H:%M:%S')
if start_date:
if dt < start_date:
continue
if end_date:
if dt > end_date:
break
bar = BarData(
symbol=symbol,
exchange=exchange,
datetime=dt,
interval=Interval.MINUTE,
volume=float(item['volume']),
open_price=float(item['open']),
high_price=float(item['high']),
low_price=float(item['low']),
close_price=float(item['close']),
open_interest=float(item['open_interest']),
trading_day=item['trading_day'],
gateway_name="Tdx",
)
bars.append(bar)
# do some statistics
count += 1
return bars