[bug fix]
This commit is contained in:
parent
58f388e67b
commit
c2ed8bfba1
@ -524,7 +524,9 @@ class RestClient(object):
|
||||
client=self,
|
||||
)
|
||||
request = self.sign(request)
|
||||
|
||||
if request.path.startswith('http'):
|
||||
url = request.path
|
||||
else:
|
||||
url = self.make_full_url(request.path)
|
||||
|
||||
response = requests.request(
|
||||
|
@ -1101,21 +1101,23 @@ class CtaStockTemplate(CtaTemplate):
|
||||
acc_symbol_pos = self.cta_engine.get_position(
|
||||
vt_symbol=ordering_grid.vt_symbol,
|
||||
direction=Direction.NET)
|
||||
if acc_symbol_pos is None:
|
||||
self.write_error(f'{self.strategy_name}当前{ordering_grid.vt_symbol}持仓查询不到, 无法执行卖出')
|
||||
continue
|
||||
|
||||
vt_symbol = ordering_grid.vt_symbol
|
||||
cn_name = self.cta_engine.get_name(ordering_grid.vt_symbol)
|
||||
sell_volume = ordering_grid.volume - ordering_grid.traded_volume
|
||||
|
||||
if acc_symbol_pos is None:
|
||||
self.write_error(f'{self.strategy_name}当前{vt_symbol}[{cn_name}]持仓查询不到, 无法执行卖出[{sell_volume}]')
|
||||
continue
|
||||
|
||||
if sell_volume > acc_symbol_pos.volume:
|
||||
if not force:
|
||||
self.write_error(u'账号{}持仓{},不满足减仓目标:{}'
|
||||
.format(vt_symbol, acc_symbol_pos.volume, sell_volume))
|
||||
self.write_error(u'账号{}[{}]持仓{},不满足减仓目标:{}'
|
||||
.format(vt_symbol,cn_name, acc_symbol_pos.volume, sell_volume))
|
||||
continue
|
||||
else:
|
||||
self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
|
||||
.format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume,
|
||||
self.write_log(u'账号{}[{}]持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
|
||||
.format(vt_symbol,cn_name, acc_symbol_pos.volume, sell_volume, sell_volume,
|
||||
acc_symbol_pos.volume))
|
||||
sell_volume = acc_symbol_pos.volume
|
||||
|
||||
|
@ -6965,7 +6965,7 @@ class CtaMinuteBar(CtaLineBar):
|
||||
|
||||
# 更新最后价格
|
||||
self.cur_price = bar.close_price
|
||||
self.cur_datetime = bar.datetime
|
||||
self.cur_datetime = bar.datetime + timedelta(minutes=bar_freq)
|
||||
|
||||
bar_len = len(self.line_bar)
|
||||
|
||||
|
@ -70,7 +70,8 @@ POSITION_DIRECTION_XTP2VT = {
|
||||
# 委托单类型
|
||||
ORDERTYPE_XTP2VT: Dict[int, OrderType] = {
|
||||
1: OrderType.LIMIT,
|
||||
2: OrderType.MARKET
|
||||
2: OrderType.MARKET,
|
||||
4: OrderType.MARKET
|
||||
}
|
||||
ORDERTYPE_VT2XTP: Dict[OrderType, int] = {v: k for k, v in ORDERTYPE_XTP2VT.items()}
|
||||
|
||||
@ -550,7 +551,7 @@ class XtpTdApi(TdApi):
|
||||
exchange=MARKET_XTP2VT[data["market"]],
|
||||
orderid=str(data["order_xtp_id"]),
|
||||
sys_orderid=str(data["order_xtp_id"]),
|
||||
type=ORDERTYPE_XTP2VT[data["price_type"]],
|
||||
type=ORDERTYPE_XTP2VT.get(data["price_type"], OrderType.LIMIT),
|
||||
direction=direction,
|
||||
offset=offset,
|
||||
price=data["price"],
|
||||
|
@ -333,7 +333,7 @@ class ContractData(BaseData):
|
||||
net_position: bool = False # whether gateway uses net position volume
|
||||
history_data: bool = False # whether gateway provides bar history data
|
||||
|
||||
option_strike: float = 0
|
||||
option_strike: float = 0 # 行权价
|
||||
option_underlying: str = "" # vt_symbol of underlying contract
|
||||
option_type: OptionType = None
|
||||
option_expiry: datetime = None
|
||||
|
@ -1536,6 +1536,9 @@ class GridKline(QtWidgets.QWidget):
|
||||
# 配置项3:sub_indicators, 副图指标
|
||||
# 指标变量必须在data_file文件中存在字段
|
||||
|
||||
# 配置项目: trade_symbol_filters,交易记录过滤,缺省[]时不执行过滤
|
||||
# 根据交易记录中得symbol字段内容进行过滤,满足过滤条件得交易记录才使用。
|
||||
|
||||
# 配置项4:trade_list_file,开平仓交易记录
|
||||
# 每条记录包含开仓,平仓,收益信息
|
||||
# 回测时,每个策略实例,都产生trade_list.csv文件
|
||||
@ -1690,6 +1693,8 @@ class GridKline(QtWidgets.QWidget):
|
||||
main_indicators=kline_setting.get('main_indicators', []),
|
||||
sub_indicators=kline_setting.get('sub_indicators', [])
|
||||
)
|
||||
# 交易记录过滤
|
||||
trade_symbol_filters = kline_setting.get('trade_symbol_filters', [])
|
||||
|
||||
# 加载开、平仓的交易信号(一般是回测系统产生的)
|
||||
trade_list_file = kline_setting.get('trade_list_file', None)
|
||||
@ -1697,6 +1702,11 @@ class GridKline(QtWidgets.QWidget):
|
||||
print(f'loading {trade_list_file}')
|
||||
t1 = datetime.now()
|
||||
df_trade_list = pd.read_csv(trade_list_file)
|
||||
|
||||
# 如果需要过滤记录,过滤vt_symbol这个字段
|
||||
if len(trade_symbol_filters) > 0 and 'vt_symbol' in df_trade_list.columns:
|
||||
df_trade_list = df_trade_list[df_trade_list.vt_symbol.str.contains("|".join(trade_symbol_filters)).any(level=0)]
|
||||
|
||||
self.kline_dict[kline_name].add_signals(df_trade_list)
|
||||
t2 = datetime.now()
|
||||
s = (t2-t1).microseconds
|
||||
@ -1708,6 +1718,10 @@ class GridKline(QtWidgets.QWidget):
|
||||
print(f'loading {trade_file}')
|
||||
t1 = datetime.now()
|
||||
df_trade = pd.read_csv(trade_file)
|
||||
# 如果需要过滤记录,过滤vt_symbol这个字段
|
||||
if len(trade_symbol_filters)> 0 and 'vt_symbol' in df_trade.columns:
|
||||
df_trade = df_trade[df_trade.vt_symbol.str.contains("|".join(trade_symbol_filters)).any(level=0)]
|
||||
|
||||
t2 = datetime.now()
|
||||
s = (t2 - t1).microseconds
|
||||
print(f'finished load in {s} ms')
|
||||
|
@ -31,7 +31,9 @@ class UiSnapshot(object):
|
||||
dist_file: str = "",
|
||||
dist_include_list=[],
|
||||
use_grid=True,
|
||||
export_file=""):
|
||||
export_file="",
|
||||
kline_filters=[],
|
||||
symbol_filters=[]):
|
||||
"""
|
||||
显示切片
|
||||
:param snapshot_file: 切片文件路径(通过这个方法,可读取历史切片)
|
||||
@ -40,6 +42,9 @@ class UiSnapshot(object):
|
||||
:param dist_file: 格式化策略逻辑日志文件
|
||||
:param dist_include_list: 逻辑日志中,operation字段内需要过滤显示的内容
|
||||
:param use_grid: 使用同一窗口
|
||||
:param export_file: 界面生成图片得文件
|
||||
:param kline_filters: 缺省为空,即全部K线,如果约定只显示得K线,则在里面放入过滤条件,满足过滤条件得即可通过
|
||||
:param symbol_filters: 交易记录数据过滤,缺省为空不过滤。根据交易记录里面得symbol进行过滤。一般用于单策略,多合约交易时过滤交易记录
|
||||
:return:
|
||||
"""
|
||||
if d is None:
|
||||
@ -60,6 +65,12 @@ class UiSnapshot(object):
|
||||
|
||||
kline_settings = {}
|
||||
for k, v in klines.items():
|
||||
|
||||
# 如果存在k线名称过滤
|
||||
if len(kline_filters) > 0:
|
||||
if not any([name in k for name in kline_filters]):
|
||||
continue
|
||||
|
||||
# 获取bar各种数据/指标列表
|
||||
data_list = v.pop('data_list', None)
|
||||
if data_list is None:
|
||||
@ -84,6 +95,10 @@ class UiSnapshot(object):
|
||||
if len(trade_file) > 0 and os.path.exists(trade_file):
|
||||
setting.update({"trade_file": trade_file})
|
||||
|
||||
# 交易合约过滤
|
||||
if len(symbol_filters) > 0:
|
||||
setting.update({'trade_symbol_filters': symbol_filters})
|
||||
|
||||
# 加载本地data目录的事务
|
||||
if len(tns_file) > 0 and os.path.exists(tns_file):
|
||||
setting.update({"tns_file": tns_file})
|
||||
|
@ -337,7 +337,8 @@ def get_digits(value: float) -> int:
|
||||
|
||||
def print_dict(d: dict):
|
||||
"""返回dict的字符串类型"""
|
||||
return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())])
|
||||
max_key_len = max([len(str(k)) for k in d.keys()])
|
||||
return '\n'.join([str(key) + (max_key_len-len(str(key))) * " " + f': {d[key]}' for key in sorted(d.keys())])
|
||||
|
||||
|
||||
def get_csv_last_dt(file_name, dt_index=0, dt_format='%Y-%m-%d %H:%M:%S', line_length=1000):
|
||||
|
Loading…
Reference in New Issue
Block a user