[bug fix]

This commit is contained in:
msincenselee 2021-11-25 10:24:14 +08:00
parent 58f388e67b
commit c2ed8bfba1
8 changed files with 50 additions and 15 deletions

View File

@ -524,7 +524,9 @@ class RestClient(object):
client=self, client=self,
) )
request = self.sign(request) request = self.sign(request)
if request.path.startswith('http'):
url = request.path
else:
url = self.make_full_url(request.path) url = self.make_full_url(request.path)
response = requests.request( response = requests.request(

View File

@ -1101,21 +1101,23 @@ class CtaStockTemplate(CtaTemplate):
acc_symbol_pos = self.cta_engine.get_position( acc_symbol_pos = self.cta_engine.get_position(
vt_symbol=ordering_grid.vt_symbol, vt_symbol=ordering_grid.vt_symbol,
direction=Direction.NET) direction=Direction.NET)
if acc_symbol_pos is None:
self.write_error(f'{self.strategy_name}当前{ordering_grid.vt_symbol}持仓查询不到, 无法执行卖出')
continue
vt_symbol = ordering_grid.vt_symbol vt_symbol = ordering_grid.vt_symbol
cn_name = self.cta_engine.get_name(ordering_grid.vt_symbol)
sell_volume = ordering_grid.volume - ordering_grid.traded_volume sell_volume = ordering_grid.volume - ordering_grid.traded_volume
if acc_symbol_pos is None:
self.write_error(f'{self.strategy_name}当前{vt_symbol}[{cn_name}]持仓查询不到, 无法执行卖出[{sell_volume}]')
continue
if sell_volume > acc_symbol_pos.volume: if sell_volume > acc_symbol_pos.volume:
if not force: if not force:
self.write_error(u'账号{}持仓{},不满足减仓目标:{}' self.write_error(u'账号{}[{}]持仓{},不满足减仓目标:{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume)) .format(vt_symbol,cn_name, acc_symbol_pos.volume, sell_volume))
continue continue
else: else:
self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}' self.write_log(u'账号{}[{}]持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume, .format(vt_symbol,cn_name, acc_symbol_pos.volume, sell_volume, sell_volume,
acc_symbol_pos.volume)) acc_symbol_pos.volume))
sell_volume = acc_symbol_pos.volume sell_volume = acc_symbol_pos.volume

View File

@ -6965,7 +6965,7 @@ class CtaMinuteBar(CtaLineBar):
# 更新最后价格 # 更新最后价格
self.cur_price = bar.close_price self.cur_price = bar.close_price
self.cur_datetime = bar.datetime self.cur_datetime = bar.datetime + timedelta(minutes=bar_freq)
bar_len = len(self.line_bar) bar_len = len(self.line_bar)

View File

@ -70,7 +70,8 @@ POSITION_DIRECTION_XTP2VT = {
# 委托单类型 # 委托单类型
ORDERTYPE_XTP2VT: Dict[int, OrderType] = { ORDERTYPE_XTP2VT: Dict[int, OrderType] = {
1: OrderType.LIMIT, 1: OrderType.LIMIT,
2: OrderType.MARKET 2: OrderType.MARKET,
4: OrderType.MARKET
} }
ORDERTYPE_VT2XTP: Dict[OrderType, int] = {v: k for k, v in ORDERTYPE_XTP2VT.items()} ORDERTYPE_VT2XTP: Dict[OrderType, int] = {v: k for k, v in ORDERTYPE_XTP2VT.items()}
@ -550,7 +551,7 @@ class XtpTdApi(TdApi):
exchange=MARKET_XTP2VT[data["market"]], exchange=MARKET_XTP2VT[data["market"]],
orderid=str(data["order_xtp_id"]), orderid=str(data["order_xtp_id"]),
sys_orderid=str(data["order_xtp_id"]), sys_orderid=str(data["order_xtp_id"]),
type=ORDERTYPE_XTP2VT[data["price_type"]], type=ORDERTYPE_XTP2VT.get(data["price_type"], OrderType.LIMIT),
direction=direction, direction=direction,
offset=offset, offset=offset,
price=data["price"], price=data["price"],

View File

@ -333,7 +333,7 @@ class ContractData(BaseData):
net_position: bool = False # whether gateway uses net position volume net_position: bool = False # whether gateway uses net position volume
history_data: bool = False # whether gateway provides bar history data history_data: bool = False # whether gateway provides bar history data
option_strike: float = 0 option_strike: float = 0 # 行权价
option_underlying: str = "" # vt_symbol of underlying contract option_underlying: str = "" # vt_symbol of underlying contract
option_type: OptionType = None option_type: OptionType = None
option_expiry: datetime = None option_expiry: datetime = None

View File

@ -1536,6 +1536,9 @@ class GridKline(QtWidgets.QWidget):
# 配置项3sub_indicators 副图指标 # 配置项3sub_indicators 副图指标
# 指标变量必须在data_file文件中存在字段 # 指标变量必须在data_file文件中存在字段
# 配置项目: trade_symbol_filters,交易记录过滤,缺省[]时不执行过滤
# 根据交易记录中得symbol字段内容进行过滤满足过滤条件得交易记录才使用。
# 配置项4trade_list_file开平仓交易记录 # 配置项4trade_list_file开平仓交易记录
# 每条记录包含开仓,平仓,收益信息 # 每条记录包含开仓,平仓,收益信息
# 回测时每个策略实例都产生trade_list.csv文件 # 回测时每个策略实例都产生trade_list.csv文件
@ -1690,6 +1693,8 @@ class GridKline(QtWidgets.QWidget):
main_indicators=kline_setting.get('main_indicators', []), main_indicators=kline_setting.get('main_indicators', []),
sub_indicators=kline_setting.get('sub_indicators', []) sub_indicators=kline_setting.get('sub_indicators', [])
) )
# 交易记录过滤
trade_symbol_filters = kline_setting.get('trade_symbol_filters', [])
# 加载开、平仓的交易信号(一般是回测系统产生的) # 加载开、平仓的交易信号(一般是回测系统产生的)
trade_list_file = kline_setting.get('trade_list_file', None) trade_list_file = kline_setting.get('trade_list_file', None)
@ -1697,6 +1702,11 @@ class GridKline(QtWidgets.QWidget):
print(f'loading {trade_list_file}') print(f'loading {trade_list_file}')
t1 = datetime.now() t1 = datetime.now()
df_trade_list = pd.read_csv(trade_list_file) df_trade_list = pd.read_csv(trade_list_file)
# 如果需要过滤记录过滤vt_symbol这个字段
if len(trade_symbol_filters) > 0 and 'vt_symbol' in df_trade_list.columns:
df_trade_list = df_trade_list[df_trade_list.vt_symbol.str.contains("|".join(trade_symbol_filters)).any(level=0)]
self.kline_dict[kline_name].add_signals(df_trade_list) self.kline_dict[kline_name].add_signals(df_trade_list)
t2 = datetime.now() t2 = datetime.now()
s = (t2-t1).microseconds s = (t2-t1).microseconds
@ -1708,6 +1718,10 @@ class GridKline(QtWidgets.QWidget):
print(f'loading {trade_file}') print(f'loading {trade_file}')
t1 = datetime.now() t1 = datetime.now()
df_trade = pd.read_csv(trade_file) df_trade = pd.read_csv(trade_file)
# 如果需要过滤记录过滤vt_symbol这个字段
if len(trade_symbol_filters)> 0 and 'vt_symbol' in df_trade.columns:
df_trade = df_trade[df_trade.vt_symbol.str.contains("|".join(trade_symbol_filters)).any(level=0)]
t2 = datetime.now() t2 = datetime.now()
s = (t2 - t1).microseconds s = (t2 - t1).microseconds
print(f'finished load in {s} ms') print(f'finished load in {s} ms')

View File

@ -31,7 +31,9 @@ class UiSnapshot(object):
dist_file: str = "", dist_file: str = "",
dist_include_list=[], dist_include_list=[],
use_grid=True, use_grid=True,
export_file=""): export_file="",
kline_filters=[],
symbol_filters=[]):
""" """
显示切片 显示切片
:param snapshot_file: 切片文件路径通过这个方法可读取历史切片 :param snapshot_file: 切片文件路径通过这个方法可读取历史切片
@ -40,6 +42,9 @@ class UiSnapshot(object):
:param dist_file: 格式化策略逻辑日志文件 :param dist_file: 格式化策略逻辑日志文件
:param dist_include_list: 逻辑日志中operation字段内需要过滤显示的内容 :param dist_include_list: 逻辑日志中operation字段内需要过滤显示的内容
:param use_grid: 使用同一窗口 :param use_grid: 使用同一窗口
:param export_file: 界面生成图片得文件
:param kline_filters: 缺省为空即全部K线如果约定只显示得K线则在里面放入过滤条件满足过滤条件得即可通过
:param symbol_filters: 交易记录数据过滤缺省为空不过滤根据交易记录里面得symbol进行过滤一般用于单策略多合约交易时过滤交易记录
:return: :return:
""" """
if d is None: if d is None:
@ -60,6 +65,12 @@ class UiSnapshot(object):
kline_settings = {} kline_settings = {}
for k, v in klines.items(): for k, v in klines.items():
# 如果存在k线名称过滤
if len(kline_filters) > 0:
if not any([name in k for name in kline_filters]):
continue
# 获取bar各种数据/指标列表 # 获取bar各种数据/指标列表
data_list = v.pop('data_list', None) data_list = v.pop('data_list', None)
if data_list is None: if data_list is None:
@ -84,6 +95,10 @@ class UiSnapshot(object):
if len(trade_file) > 0 and os.path.exists(trade_file): if len(trade_file) > 0 and os.path.exists(trade_file):
setting.update({"trade_file": trade_file}) setting.update({"trade_file": trade_file})
# 交易合约过滤
if len(symbol_filters) > 0:
setting.update({'trade_symbol_filters': symbol_filters})
# 加载本地data目录的事务 # 加载本地data目录的事务
if len(tns_file) > 0 and os.path.exists(tns_file): if len(tns_file) > 0 and os.path.exists(tns_file):
setting.update({"tns_file": tns_file}) setting.update({"tns_file": tns_file})

View File

@ -337,7 +337,8 @@ def get_digits(value: float) -> int:
def print_dict(d: dict): def print_dict(d: dict):
"""返回dict的字符串类型""" """返回dict的字符串类型"""
return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())]) max_key_len = max([len(str(k)) for k in d.keys()])
return '\n'.join([str(key) + (max_key_len-len(str(key))) * " " + f': {d[key]}' for key in sorted(d.keys())])
def get_csv_last_dt(file_name, dt_index=0, dt_format='%Y-%m-%d %H:%M:%S', line_length=1000): def get_csv_last_dt(file_name, dt_index=0, dt_format='%Y-%m-%d %H:%M:%S', line_length=1000):