[update]
This commit is contained in:
parent
b439a419e7
commit
7ce42e54a4
@ -210,6 +210,8 @@ class CtaLineBar(object):
|
||||
def init_param_list(self):
|
||||
self.paramList.append('bar_interval')
|
||||
self.paramList.append('interval')
|
||||
self.paramList.append('mode')
|
||||
|
||||
self.paramList.append('para_pre_len')
|
||||
self.paramList.append('para_ma1_len')
|
||||
self.paramList.append('para_ma2_len')
|
||||
|
@ -133,7 +133,8 @@ class CtaEngine(BaseEngine):
|
||||
self.stop_order_count = 0 # for generating stop_orderid
|
||||
self.stop_orders = {} # stop_orderid: stop_order
|
||||
|
||||
self.init_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.thread_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.thread_tasks = []
|
||||
|
||||
self.vt_tradeids = set() # for filtering duplicate trade
|
||||
|
||||
@ -722,7 +723,7 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
# 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射
|
||||
subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name]
|
||||
subscribe_symbol_set.add(contract.vt_symbol)
|
||||
subscribe_symbol_set.add(vt_symbol)
|
||||
|
||||
return True
|
||||
|
||||
@ -777,7 +778,8 @@ class CtaEngine(BaseEngine):
|
||||
accounts = self.main_engine.get_all_accounts()
|
||||
if len(accounts) > 0:
|
||||
account = accounts[0]
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01),
|
||||
2), 100
|
||||
else:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
@ -858,7 +860,7 @@ class CtaEngine(BaseEngine):
|
||||
subscribe_symbol_set.add(vt_symbol)
|
||||
|
||||
# Update to setting file.
|
||||
self.update_strategy_setting(strategy_name, setting)
|
||||
self.update_strategy_setting(strategy_name, setting, auto_init, auto_start)
|
||||
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
@ -870,7 +872,8 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Init a strategy.
|
||||
"""
|
||||
self.init_executor.submit(self._init_strategy, strategy_name, auto_start)
|
||||
task = self.thread_executor.submit(self._init_strategy, strategy_name, auto_start)
|
||||
self.thread_tasks.append(task)
|
||||
|
||||
def _init_strategy(self, strategy_name: str, auto_start: bool = False):
|
||||
"""
|
||||
@ -959,9 +962,12 @@ class CtaEngine(BaseEngine):
|
||||
风险警示: 该方法强行干预策略的配置
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
auto_init = setting.pop('auto_init', False)
|
||||
auto_start = setting.pop('auto_start', False)
|
||||
|
||||
strategy.update_setting(setting)
|
||||
|
||||
self.update_strategy_setting(strategy_name, setting)
|
||||
self.update_strategy_setting(strategy_name, setting, auto_init, auto_start)
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def remove_strategy(self, strategy_name: str):
|
||||
@ -1012,6 +1018,12 @@ class CtaEngine(BaseEngine):
|
||||
if strategy_name not in self.strategies or strategy_name not in self.strategy_setting:
|
||||
self.write_error(f"{strategy_name}不在运行策略中,不能重启")
|
||||
return False
|
||||
|
||||
# 从本地配置文件中读取
|
||||
if len(setting) == 0:
|
||||
strategies_setting = load_json(self.setting_filename)
|
||||
old_strategy_config = strategies_setting.get(strategy_name, {})
|
||||
else:
|
||||
old_strategy_config = copy(self.strategy_setting[strategy_name])
|
||||
|
||||
class_name = old_strategy_config.get('class_name')
|
||||
@ -1035,11 +1047,98 @@ class CtaEngine(BaseEngine):
|
||||
self.add_strategy(class_name=class_name,
|
||||
strategy_name=strategy_name,
|
||||
vt_symbol=vt_symbol,
|
||||
setting=setting)
|
||||
setting=setting,
|
||||
auto_init=old_strategy_config.get('auto_init', False),
|
||||
auto_start=old_strategy_config.get('auto_start', False))
|
||||
|
||||
self.write_log(f'重新运行策略{strategy_name}执行完毕')
|
||||
return True
|
||||
|
||||
def save_strategy_data(self, select_name: str):
|
||||
""" save strategy data"""
|
||||
has_executed = False
|
||||
msg = ""
|
||||
# 1.判断策略名称是否存在字典中
|
||||
for strategy_name in list(self.strategies.keys()):
|
||||
if select_name != 'ALL':
|
||||
if strategy_name != select_name:
|
||||
continue
|
||||
# 2.提取策略
|
||||
strategy = self.strategies.get(strategy_name, None)
|
||||
if not strategy:
|
||||
continue
|
||||
|
||||
# 3.判断策略是否运行
|
||||
if strategy.inited and strategy.trading:
|
||||
task = self.thread_executor.submit(self.thread_save_strategy_data, strategy_name)
|
||||
self.thread_tasks.append(task)
|
||||
msg += f'{strategy_name}执行保存数据\n'
|
||||
has_executed = True
|
||||
else:
|
||||
self.write_log(f'{strategy_name}未初始化/未启动交易,不进行保存数据')
|
||||
return has_executed, msg
|
||||
|
||||
def thread_save_strategy_data(self, strategy_name):
|
||||
"""异步线程保存策略数据"""
|
||||
strategy = self.strategies.get(strategy_name, None)
|
||||
if strategy is None:
|
||||
return
|
||||
try:
|
||||
# 保存策略数据
|
||||
strategy.sync_data()
|
||||
except Exception as ex:
|
||||
self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex)))
|
||||
self.write_error(traceback.format_exc())
|
||||
|
||||
def save_strategy_snapshot(self, select_name: str):
|
||||
"""
|
||||
保存策略K线切片数据
|
||||
:param select_name:
|
||||
:return:
|
||||
"""
|
||||
has_executed = False
|
||||
msg = ""
|
||||
# 1.判断策略名称是否存在字典中
|
||||
for strategy_name in list(self.strategies.keys()):
|
||||
if select_name != 'ALL':
|
||||
if strategy_name != select_name:
|
||||
continue
|
||||
# 2.提取策略
|
||||
strategy = self.strategies.get(strategy_name, None)
|
||||
if not strategy:
|
||||
continue
|
||||
|
||||
if not hasattr(strategy, 'get_klines_snapshot'):
|
||||
continue
|
||||
|
||||
# 3.判断策略是否运行
|
||||
if strategy.inited and strategy.trading:
|
||||
task = self.thread_executor.submit(self.thread_save_strategy_snapshot, strategy_name)
|
||||
self.thread_tasks.append(task)
|
||||
msg += f'{strategy_name}执行保存K线切片\n'
|
||||
has_executed = True
|
||||
|
||||
return has_executed, msg
|
||||
|
||||
def thread_save_strategy_snapshot(self, strategy_name):
|
||||
"""异步线程保存策略切片"""
|
||||
strategy = self.strategies.get(strategy_name, None)
|
||||
if strategy is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# 5.保存策略切片
|
||||
snapshot = strategy.get_klines_snapshot()
|
||||
if len(snapshot) == 0:
|
||||
self.write_log(f'{strategy_name}返回得K线切片数据为空')
|
||||
return
|
||||
|
||||
# 剩下工作:保存本地文件/数据库
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex)))
|
||||
self.write_error(traceback.format_exc())
|
||||
|
||||
def load_strategy_class(self):
|
||||
"""
|
||||
Load strategy class from source code.
|
||||
@ -1209,7 +1308,7 @@ class CtaEngine(BaseEngine):
|
||||
# SPD合约
|
||||
spd_vt_symbol = pos.get('vt_symbol', None)
|
||||
if spd_vt_symbol is not None and spd_vt_symbol.endswith('SPD'):
|
||||
spd_symbol,spd_exchange = extract_vt_symbol(spd_vt_symbol)
|
||||
spd_symbol, spd_exchange = extract_vt_symbol(spd_vt_symbol)
|
||||
spd_setting = self.main_engine.get_all_custom_contracts().get(spd_symbol, None)
|
||||
|
||||
if spd_setting is None:
|
||||
@ -1225,7 +1324,7 @@ class CtaEngine(BaseEngine):
|
||||
leg1_pos.update({'symbol': spd_setting.get('leg1_symbol')})
|
||||
leg1_pos.update({'vt_symbol': spd_setting.get('leg1_symbol')})
|
||||
leg1_pos.update({'direction': leg1_direction})
|
||||
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1)*spd_volume})
|
||||
leg1_pos.update({'volume': spd_setting.get('leg1_ratio', 1) * spd_volume})
|
||||
|
||||
leg2_pos = {}
|
||||
leg2_pos.update({'symbol': spd_setting.get('leg2_symbol')})
|
||||
@ -1288,7 +1387,12 @@ class CtaEngine(BaseEngine):
|
||||
Get parameters of a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
return strategy.get_parameters()
|
||||
strategy_config = self.strategy_setting.get(strategy_name, {})
|
||||
d = {}
|
||||
d.update({'auto_init': strategy_config.get('auto_init', False)})
|
||||
d.update({'auto_start': strategy_config.get('auto_start', False)})
|
||||
d.update(strategy.get_parameters())
|
||||
return d
|
||||
|
||||
def init_all_strategies(self):
|
||||
"""
|
||||
@ -1328,7 +1432,8 @@ class CtaEngine(BaseEngine):
|
||||
auto_start=strategy_config.get('auto_start', False)
|
||||
)
|
||||
|
||||
def update_strategy_setting(self, strategy_name: str, setting: dict):
|
||||
def update_strategy_setting(self, strategy_name: str, setting: dict, auto_init: bool = False,
|
||||
auto_start: bool = False):
|
||||
"""
|
||||
Update setting file.
|
||||
"""
|
||||
@ -1339,8 +1444,8 @@ class CtaEngine(BaseEngine):
|
||||
self.strategy_setting[strategy_name] = {
|
||||
"class_name": strategy.__class__.__name__,
|
||||
"vt_symbol": strategy.vt_symbol,
|
||||
"auto_init": strategy_config.get('auto_init', False),
|
||||
"auto_start": strategy_config.get('auto_start', False),
|
||||
"auto_init": auto_init,
|
||||
"auto_start": auto_start,
|
||||
"setting": setting
|
||||
}
|
||||
save_json(self.setting_filename, self.strategy_setting)
|
||||
|
@ -897,7 +897,7 @@ class PortfolioTestingEngine(object):
|
||||
"""保存策略数据"""
|
||||
for strategy in self.strategies.values():
|
||||
self.write_log(u'save strategy data')
|
||||
strategy.saveData()
|
||||
strategy.save_data()
|
||||
|
||||
def send_order(self,
|
||||
strategy: CtaTemplate,
|
||||
@ -2274,6 +2274,9 @@ class PortfolioTestingEngine(object):
|
||||
d = OrderedDict()
|
||||
try:
|
||||
for k in trade_fields:
|
||||
if k in ['exchange', 'direction', 'offset']:
|
||||
d[k] = getattr(trade, k).value
|
||||
else:
|
||||
d[k] = getattr(trade, k, '')
|
||||
|
||||
trade_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade.csv'.format(strategy_name)))
|
||||
|
@ -526,6 +526,8 @@ class CtaProTemplate(CtaTemplate):
|
||||
增强模板
|
||||
"""
|
||||
|
||||
backtesting = False
|
||||
|
||||
# 逻辑过程日志
|
||||
dist_fieldnames = ['datetime', 'symbol', 'volume', 'price',
|
||||
'operation', 'signal', 'stop_price', 'target_price',
|
||||
@ -552,8 +554,6 @@ class CtaProTemplate(CtaTemplate):
|
||||
|
||||
self.cancel_seconds = 120 # 撤单时间(秒)
|
||||
|
||||
self.backtesting = False
|
||||
|
||||
self.klines = {} # K线字典: kline_name: kline
|
||||
|
||||
# 增加仓位管理模块
|
||||
@ -601,7 +601,11 @@ class CtaProTemplate(CtaTemplate):
|
||||
with bz2.BZ2File(file_name, 'wb') as f:
|
||||
klines = {}
|
||||
for kline_name in kline_names:
|
||||
klines.update({kline_name: self.klines.get(kline_name, None)})
|
||||
kline = self.klines.get(kline_name, None)
|
||||
#if kline:
|
||||
# kline.strategy = None
|
||||
# kline.cb_on_bar = None
|
||||
klines.update({kline_name: kline})
|
||||
pickle.dump(klines, f)
|
||||
|
||||
def load_klines_from_cache(self, kline_names: list = []):
|
||||
@ -657,11 +661,11 @@ class CtaProTemplate(CtaTemplate):
|
||||
'strategy': self.strategy_name,
|
||||
'datetime': datetime.now()}
|
||||
|
||||
for kline_name in self.klines.keys():
|
||||
for kline_name in sorted(self.klines.keys()):
|
||||
d.update({kline_name: self.klines.get(kline_name).get_data()})
|
||||
return d
|
||||
except Exception as ex:
|
||||
self.write_error(u'获取klines切片数据失败')
|
||||
self.write_error(f'获取klines切片数据失败:{str(ex)}')
|
||||
return {}
|
||||
|
||||
def init_position(self):
|
||||
@ -959,8 +963,9 @@ class CtaProTemplate(CtaTemplate):
|
||||
else:
|
||||
save_path = self.cta_engine.get_data_path()
|
||||
try:
|
||||
if self.position:
|
||||
if self.position and 'long_pos' not in dist_data:
|
||||
dist_data.update({'long_pos': self.position.long_pos})
|
||||
if self.position and 'short_pos' not in dist_data:
|
||||
dist_data.update({'short_pos': self.position.short_pos})
|
||||
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv'))
|
||||
|
@ -144,9 +144,10 @@ class CtaManager(QtWidgets.QWidget):
|
||||
setting = editor.get_setting()
|
||||
vt_symbol = setting.pop("vt_symbol")
|
||||
strategy_name = setting.pop("strategy_name")
|
||||
|
||||
auto_init = setting.pop("auto_init", False)
|
||||
auto_start = setting.pop("auto_start", False)
|
||||
self.cta_engine.add_strategy(
|
||||
class_name, strategy_name, vt_symbol, setting
|
||||
class_name, strategy_name, vt_symbol, setting, auto_init, auto_start
|
||||
)
|
||||
|
||||
def clear_log(self):
|
||||
@ -201,6 +202,9 @@ class StrategyManager(QtWidgets.QFrame):
|
||||
reload_button = QtWidgets.QPushButton("重载")
|
||||
reload_button.clicked.connect(self.reload_strategy)
|
||||
|
||||
save_button = QtWidgets.QPushButton("保存")
|
||||
save_button.clicked.connect(self.save_strategy)
|
||||
|
||||
strategy_name = self._data["strategy_name"]
|
||||
vt_symbol = self._data["vt_symbol"]
|
||||
class_name = self._data["class_name"]
|
||||
@ -222,6 +226,7 @@ class StrategyManager(QtWidgets.QFrame):
|
||||
hbox.addWidget(edit_button)
|
||||
hbox.addWidget(remove_button)
|
||||
hbox.addWidget(reload_button)
|
||||
hbox.addWidget(save_button)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(label)
|
||||
@ -273,6 +278,9 @@ class StrategyManager(QtWidgets.QFrame):
|
||||
"""重新加载策略"""
|
||||
self.cta_engine.reload_strategy(self.strategy_name)
|
||||
|
||||
def save_strategy(self):
|
||||
self.cta_engine.save_strategy_data(self.strategy_name)
|
||||
|
||||
|
||||
class DataMonitor(QtWidgets.QTableWidget):
|
||||
"""
|
||||
@ -403,8 +411,9 @@ class SettingEditor(QtWidgets.QDialog):
|
||||
if self.class_name:
|
||||
self.setWindowTitle(f"添加策略:{self.class_name}")
|
||||
button_text = "添加"
|
||||
parameters = {"strategy_name": "", "vt_symbol": ""}
|
||||
parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True}
|
||||
parameters.update(self.parameters)
|
||||
|
||||
else:
|
||||
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
|
||||
button_text = "确定"
|
||||
|
@ -1219,3 +1219,59 @@ def get_file_logger(filename: str):
|
||||
handler.setFormatter(log_formatter)
|
||||
logger.addHandler(handler) # each handler will be added only once.
|
||||
return logger
|
||||
|
||||
|
||||
def get_bars(csv_file: str,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
start_date: datetime = None,
|
||||
end_date: datetime = None,):
|
||||
"""
|
||||
获取bar
|
||||
数据存储目录: 项目/bar_data
|
||||
:param csv_file: csv文件路径
|
||||
:param symbol: 合约
|
||||
:param exchange 交易所
|
||||
:param start_date: datetime
|
||||
:param end_date: datetime
|
||||
:return:
|
||||
"""
|
||||
bars = []
|
||||
|
||||
import csv
|
||||
with open(file=csv_file, mode='r', encoding='utf8', newline='\n') as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
count = 0
|
||||
|
||||
for item in reader:
|
||||
|
||||
dt = datetime.strptime(item['datetime'], '%Y-%m-%d %H:%M:%S')
|
||||
if start_date:
|
||||
if dt < start_date:
|
||||
continue
|
||||
if end_date:
|
||||
if dt > end_date:
|
||||
break
|
||||
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=dt,
|
||||
interval=Interval.MINUTE,
|
||||
volume=float(item['volume']),
|
||||
open_price=float(item['open']),
|
||||
high_price=float(item['high']),
|
||||
low_price=float(item['low']),
|
||||
close_price=float(item['close']),
|
||||
open_interest=float(item['open_interest']),
|
||||
trading_day=item['trading_day'],
|
||||
gateway_name="Tdx",
|
||||
)
|
||||
|
||||
bars.append(bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
|
||||
return bars
|
||||
|
Loading…
Reference in New Issue
Block a user