diff --git a/vnpy/app/cta_crypto/back_testing.py b/vnpy/app/cta_crypto/back_testing.py index 2e56b72d..c6757a5f 100644 --- a/vnpy/app/cta_crypto/back_testing.py +++ b/vnpy/app/cta_crypto/back_testing.py @@ -41,7 +41,8 @@ from vnpy.trader.object import ( TickData, OrderData, TradeData, - ContractData + ContractData, + PositionData ) from vnpy.trader.constant import ( Exchange, @@ -127,10 +128,6 @@ class BackTestingEngine(object): self.order_strategy_dict = {} # orderid 与 strategy的映射 - # 持仓缓存字典 - # key为vt_symbol,value为PositionBuffer对象 - self.pos_holding_dict = {} - self.trade_count = 0 # 成交编号 self.trade_dict = OrderedDict() # 用于统计成交收益时,还没处理得交易 self.trades = OrderedDict() # 记录所有得成交记录 @@ -139,7 +136,7 @@ class BackTestingEngine(object): self.long_position_list = [] # 多单持仓 self.short_position_list = [] # 空单持仓 - self.holdings = {} # 多空持仓 + self.positions = {} # 账号持仓,对象为PositionData # 当前最新数据,用于模拟成交用 self.gateway_name = u'BackTest' @@ -388,13 +385,13 @@ class BackTestingEngine(object): def get_exchange(self, symbol: str): return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL) - def get_position_holding(self, vt_symbol: str, gateway_name: str = ''): - """ 查询合约在账号的持仓(包含多空)""" + def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): + """ 查询合约在账号的持仓""" if not gateway_name: gateway_name = self.gateway_name - k = f'{gateway_name}.{vt_symbol}' - holding = self.holdings.get(k, None) - if not holding: + k = f'{gateway_name}.{vt_symbol}.{direction.value}' + pos = self.positions.get(k, None) + if not pos: contract = self.get_contract(vt_symbol) if not contract: self.write_log(f'{vt_symbol}合约信息不存在,构造一个') @@ -413,9 +410,14 @@ class BackTestingEngine(object): size=self.get_size(vt_symbol), pricetick=self.get_price_tick(vt_symbol), margin_rate=self.get_margin_rate(vt_symbol)) - holding = PositionHolding(contract) - self.holdings[k] = holding - return holding + pos = PositionData( + gateway_name=gateway_name, + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction + ) + self.positions[k] = pos + return pos def set_name(self, test_name): """ @@ -1081,9 +1083,13 @@ class BackTestingEngine(object): self.append_trade(trade) # 更新持仓缓存数据 - k = '.'.join([self.gateway_name, trade.vt_symbol]) - holding = self.get_position_holding(trade.vt_symbol, self.gateway_name) - holding.update_trade(trade) + pos = self.get_position(vt_symbol=trade.vt_symbol,direction=Direction.NET) + pre_volume = pos.volume + if trade.direction == Direction.LONG: + pos.volume = round(pos.volume + trade.volume, 7) + else: + pos.volume = round(pos.volume - trade.volume, 7) + self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}') strategy.on_trade(trade) @@ -1160,22 +1166,27 @@ class BackTestingEngine(object): trade.strategy_name = strategy.strategy_name # 更新持仓缓存数据 - k = '.'.join([self.gateway_name, trade.vt_symbol]) - holding = self.get_position_holding(trade.vt_symbol, self.gateway_name) - holding.update_trade(trade) - strategy.on_trade(trade) + pos = self.get_position(vt_symbol=trade.vt_symbol, direction=Direction.NET) + pre_volume = pos.volume + if trade.direction == Direction.LONG: + pos.volume = round(pos.volume + trade.volume, 7) + else: + pos.volume = round(pos.volume - trade.volume, 7) + self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}') self.trade_dict[trade.vt_tradeid] = trade self.trades[trade.vt_tradeid] = copy.copy(trade) self.write_log(u'vt_trade_id:{0}'.format(trade.vt_tradeid)) - self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(trade.strategy_name, + self.write_log(u'{} : crossLimitOrder: TradeId:{}'.format(trade.strategy_name, trade.tradeid, - holding.to_str())) + )) # 写入交易记录 self.append_trade(trade) + strategy.on_trade(trade) + # 更新资金曲线 fund_kline = self.get_fund_kline(trade.strategy_name) if fund_kline: @@ -1193,20 +1204,6 @@ class BackTestingEngine(object): # 实时计算模式 self.realtime_calculate() - def update_pos_buffer(self): - """更新持仓信息,把今仓=>昨仓""" - - for k, v in self.pos_holding_dict.items(): - if v.long_td > 0: - self.write_log(u'调整多单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.long_td, v.long_yd, v.long_pos)) - v.long_td = 0 - v.longYd = v.long_pos - - if v.short_td > 0: - self.write_log(u'调整空单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.short_td, v.short_yd, v.short_pos)) - v.short_td = 0 - v.short_yd = v.short_pos - def get_data_path(self): """ 获取数据保存目录 @@ -1842,19 +1839,6 @@ class BackTestingEngine(object): else: self.write_log(msg) - # 今仓 =》 昨仓 - for holding in self.holdings.values(): - if holding.long_td > 0: - self.write_log( - f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}') - holding.long_td = 0 - holding.long_yd = holding.long_pos - if holding.short_td > 0: - self.write_log( - f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}') - holding.short_td = 0 - holding.short_yd = holding.short_pos - # --------------------------------------------------------------------- def export_trade_result(self): """ diff --git a/vnpy/app/cta_crypto/engine.py b/vnpy/app/cta_crypto/engine.py index 9367d2f1..8f6ede37 100644 --- a/vnpy/app/cta_crypto/engine.py +++ b/vnpy/app/cta_crypto/engine.py @@ -138,7 +138,7 @@ class CtaEngine(BaseEngine): self.vt_tradeids = set() # for filtering duplicate trade - self.holdings = {} + self.positions = {} self.last_minute = None @@ -243,9 +243,6 @@ class CtaEngine(BaseEngine): """""" order = event.data - holding = self.get_position_holding(order.vt_symbol, order.gateway_name) - holding.update_order(order) - strategy = self.orderid_strategy_map.get(order.vt_orderid, None) if not strategy: return @@ -282,9 +279,6 @@ class CtaEngine(BaseEngine): return self.vt_tradeids.add(trade.vt_tradeid) - holding = self.get_position_holding(trade.vt_symbol, trade.gateway_name) - holding.update_trade(trade) - strategy = self.orderid_strategy_map.get(trade.vt_orderid, None) if not strategy: return @@ -341,8 +335,7 @@ class CtaEngine(BaseEngine): """""" position = event.data - holding = self.get_position_holding(position.vt_symbol, position.gateway_name) - holding.update_position(position) + self.positions.update({position.vt_positionid: position}) def check_unsubscribed_symbols(self): """检查未订阅合约""" @@ -786,7 +779,7 @@ class CtaEngine(BaseEngine): contract = self.main_engine.get_contract(vt_symbol) if contract is None: self.write_error(f'查询不到{vt_symbol}合约信息') - return 0.1 + return 0.001 return contract.pricetick @@ -796,7 +789,7 @@ class CtaEngine(BaseEngine): contract = self.main_engine.get_contract(vt_symbol) if contract is None: self.write_error(f'查询不到{vt_symbol}合约信息') - return 1 + return 0.01 return contract.min_volume @@ -832,22 +825,14 @@ class CtaEngine(BaseEngine): def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): """ 查询合约在账号的持仓,需要指定方向""" + contract = self.main_engine.get_contract(vt_symbol) + if contract: + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}" return self.main_engine.get_position(vt_position_id) - def get_position_holding(self, vt_symbol: str, gateway_name: str = ''): - """ 查询合约在账号的持仓(包含多空)""" - k = f'{gateway_name}.{vt_symbol}' - holding = self.holdings.get(k, None) - if not holding: - symbol, exchange = extract_vt_symbol(vt_symbol) - - contract = self.main_engine.get_contract(vt_symbol) - - holding = PositionHolding(contract) - self.holdings[k] = holding - return holding - def get_engine_type(self): """""" return self.engine_type @@ -1498,19 +1483,15 @@ class CtaEngine(BaseEngine): compare_pos = dict() # vt_symbol: {'账号多单': xx, '账号空单':xxx, '策略空单':[], '策略多单':[]} - for holding_key in list(self.holdings.keys()): + for position in list(self.positions.values()): # gateway_name.symbol.exchange => symbol.exchange - vt_symbol = '.'.join(holding_key.split('.')[-2:]) - + vt_symbol = position.vt_symbol vt_symbols.add(vt_symbol) - holding = self.holdings.get(holding_key, None) - if holding is None: - continue compare_pos[vt_symbol] = OrderedDict( { - "账号空单": holding.short_pos, - '账号多单': holding.long_pos, + "账号空单": abs(position.volume) if position.volume < 0 else 0, + '账号多单': position.volume if position.volume > 0 else 0, '策略空单': 0, '策略多单': 0, '空单策略': [], diff --git a/vnpy/app/cta_crypto/portfolio_testing.py b/vnpy/app/cta_crypto/portfolio_testing.py index e192115d..b0a367b6 100644 --- a/vnpy/app/cta_crypto/portfolio_testing.py +++ b/vnpy/app/cta_crypto/portfolio_testing.py @@ -279,8 +279,7 @@ class PortfolioTestingEngine(BackTestingEngine): # 第二个交易日,撤单 self.cancel_orders() - # 更新持仓缓存 - self.update_pos_buffer() + gc_collect_days += 1 if gc_collect_days >= 10: diff --git a/vnpy/app/cta_crypto/template.py b/vnpy/app/cta_crypto/template.py index 132d4478..57744f47 100644 --- a/vnpy/app/cta_crypto/template.py +++ b/vnpy/app/cta_crypto/template.py @@ -997,7 +997,9 @@ class CtaFutureTemplate(CtaTemplate): """ self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json())) - self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol) + self.account_pos = self.cta_engine.get_position( + vt_symbol=self.vt_symbol, + direction=Direction.NET) if self.account_pos is None: self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol)) @@ -1050,7 +1052,9 @@ class CtaFutureTemplate(CtaTemplate): """ self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json())) - self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol) + self.account_pos = self.cta_engine.get_position( + vt_symbol=self.vt_symbol, + direction=Direction.NET) if self.account_pos is None: self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol)) return False @@ -1064,17 +1068,21 @@ class CtaFutureTemplate(CtaTemplate): # 发出cover委托 if grid.traded_volume > 0: grid.volume -= grid.traded_volume + grid.volume = round(grid.volume, 7) grid.traded_volume = 0 - grid.volume = round(grid.volume, 7) - - if 0 < abs(self.account_pos.short_pos) < grid.volume: - self.write_error(u'当前{}的空单持仓:{},不满足平仓目标:{}, 强制降低' + if self.account_pos.volume >= 0: + self.write_error(u'当前{}的净持仓:{},不能平空单' .format(self.vt_symbol, - self.account_pos.short_pos, + self.account_pos.volume)) + return False + if abs(self.account_pos.volume) < grid.volume: + self.write_error(u'当前{}的净持仓:{},不满足平仓目标:{}, 强制降低' + .format(self.vt_symbol, + self.account_pos.volume, grid.volume)) - grid.volume = abs(self.account_pos.short_pos) + grid.volume = abs(self.account_pos.volume) vt_orderids = self.cover( price=cover_price, diff --git a/vnpy/component/cta_line_bar.py b/vnpy/component/cta_line_bar.py index d1d7279c..86be7891 100644 --- a/vnpy/component/cta_line_bar.py +++ b/vnpy/component/cta_line_bar.py @@ -77,11 +77,12 @@ class CtaLineBar(object): self.lineM = None # 1分钟K线 lineMSetting = {} lineMSetting['name'] = u'M1' + lineMSetting['interval'] = Interval.MINUTE lineMSetting['bar_interval'] = 60 # 1分钟对应60秒 - lineMSetting['inputEma1Len'] = 7 # EMA线1的周期 - lineMSetting['inputEma2Len'] = 21 # EMA线2的周期 - lineMSetting['inputBollLen'] = 20 # 布林特线周期 - lineMSetting['inputBollStdRate'] = 2 # 布林特线标准差 + lineMSetting['para_ema1_len'] = 7 # EMA线1的周期 + lineMSetting['para_ema2_len'] = 21 # EMA线2的周期 + lineMSetting['para_boll_len'] = 20 # 布林特线周期 + lineMSetting['para_boll_std_rate'] = 2 # 布林特线标准差 lineMSetting['price_tick'] = self.price_tick # 最小条 lineMSetting['underlying_symbol'] = self.underlying_symbol #商品短号 self.lineM = CtaLineBar(self, self.onBar, lineMSetting) @@ -207,7 +208,7 @@ class CtaLineBar(object): # 修正精度 if self.price_tick < 1: exponent = decimal.Decimal(str(self.price_tick)) - self.round_n = max(abs(exponent.as_tuple().exponent), 4) + self.round_n = max(abs(exponent.as_tuple().exponent) + 2, 4) # 导入卡尔曼过滤器 if self.para_active_kf: diff --git a/vnpy/gateway/binancef/binancef_gateway.py b/vnpy/gateway/binancef/binancef_gateway.py index ca179521..fc555529 100644 --- a/vnpy/gateway/binancef/binancef_gateway.py +++ b/vnpy/gateway/binancef/binancef_gateway.py @@ -494,10 +494,13 @@ class BinancefRestApi(RestClient): for asset in data["assets"]: account = AccountData( accountid=asset["asset"], - balance=float(asset["walletBalance"]), + balance=float(asset["walletBalance"]) + float(asset["maintMargin"]), frozen=float(asset["maintMargin"]), + holding_profit=float(asset['unrealizedProfit']), gateway_name=self.gateway_name ) + # 修正vnpy AccountData + account.balance += account.holding_profit if account.balance: self.gateway.on_account(account) @@ -516,69 +519,18 @@ class BinancefRestApi(RestClient): """""" for d in data: volume = float(d["positionAmt"]) - - if volume > 0: - long_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.LONG, - volume=abs(volume), - price=float(d["entryPrice"]), - pnl=float(d["unRealizedProfit"]), - gateway_name=self.gateway_name, - ) - short_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.SHORT, - volume=0, - price=0, - pnl=0, - gateway_name=self.gateway_name, - ) - elif volume < 0: - long_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.LONG, - volume=0, - price=0, - pnl=0, - gateway_name=self.gateway_name, - ) - short_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.SHORT, - volume=abs(volume), - price=float(d["entryPrice"]), - pnl=float(d["unRealizedProfit"]), - gateway_name=self.gateway_name, - ) - else: - long_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.LONG, - volume=0, - price=0, - pnl=0, - gateway_name=self.gateway_name, - ) - short_position = PositionData( - symbol=d["symbol"], - exchange=Exchange.BINANCE, - direction=Direction.SHORT, - volume=0, - price=0, - pnl=0, - gateway_name=self.gateway_name, - ) - - self.gateway.on_position(long_position) - self.gateway.on_position(short_position) - - # self.gateway.write_log("持仓信息查询成功") + position = PositionData( + symbol=d["symbol"], + exchange=Exchange.BINANCE, + direction=Direction.NET, + volume=volume, + price=float(d["entryPrice"]), + pnl=float(d["unRealizedProfit"]), + gateway_name=self.gateway_name, + ) + self.gateway.on_position(position) + + # self.gateway.write_log("持仓信息查询成功") def on_query_order(self, data: dict, request: Request) -> None: """""" @@ -831,15 +783,34 @@ class BinancefTradeWebsocketApi(WebsocketClient): def on_account(self, packet: dict) -> None: """""" + holding_pnl = 0 + for pos_data in packet["a"]["P"]: + print(pos_data) + volume = float(pos_data["pa"]) + position = PositionData( + symbol=pos_data["s"], + exchange=Exchange.BINANCE, + direction=Direction.NET, + volume=abs(volume), + price=float(pos_data["ep"]), + pnl=float(pos_data["cr"]), + gateway_name=self.gateway_name, + ) + holding_pnl += float(pos_data['up']) + self.gateway.on_position(position) + for acc_data in packet["a"]["B"]: account = AccountData( accountid=acc_data["a"], - balance=float(acc_data["wb"]), + balance=round(float(acc_data["wb"]), 7), frozen=float(acc_data["wb"]) - float(acc_data["cw"]), + holding_profit=round(holding_pnl, 7), gateway_name=self.gateway_name ) if account.balance: + account.balance += account.holding_profit + account.available = float(acc_data["cw"]) self.gateway.on_account(account) for pos_data in packet["a"]["P"]: diff --git a/vnpy/trader/converter.py b/vnpy/trader/converter.py index 761a3735..e9709286 100644 --- a/vnpy/trader/converter.py +++ b/vnpy/trader/converter.py @@ -130,11 +130,11 @@ class PositionHolding: if position.direction == Direction.LONG: self.long_pos = position.volume self.long_yd = position.yd_volume - self.long_td = self.long_pos - self.long_yd + self.long_td = round(self.long_pos - self.long_yd, 7) else: self.short_pos = position.volume self.short_yd = position.yd_volume - self.short_td = self.short_pos - self.short_yd + self.short_td = round(self.short_pos - self.short_yd, 7) def update_order(self, order: OrderData) -> None: """""" @@ -211,7 +211,7 @@ class PositionHolding: if order.offset == Offset.OPEN: continue - frozen = order.volume - order.traded + frozen = round(order.volume - order.traded, 7) if order.direction == Direction.LONG: if order.offset == Offset.CLOSETODAY: @@ -238,8 +238,8 @@ class PositionHolding: - self.long_td) self.long_td_frozen = self.long_td - self.long_pos_frozen = self.long_td_frozen + self.long_yd_frozen - self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen + self.long_pos_frozen = round(self.long_td_frozen + self.long_yd_frozen, 7) + self.short_pos_frozen = round(self.short_td_frozen + self.short_yd_frozen, 7) def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]: """上期所,委托单拆分""" diff --git a/vnpy/trader/ui/widget.py b/vnpy/trader/ui/widget.py index 2c8bbee1..08a71c82 100644 --- a/vnpy/trader/ui/widget.py +++ b/vnpy/trader/ui/widget.py @@ -750,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget): self.return_label.setText(f"{r:.2f}%") if tick.bid_price_2: - self.bp2_label.setText(str(round(tick.bid_price_2), 7)) + self.bp2_label.setText(str(round(tick.bid_price_2, 7))) self.bv2_label.setText(str(round(tick.bid_volume_2, 7))) self.ap2_label.setText(str(round(tick.ask_price_2, 7))) self.av2_label.setText(str(round(tick.ask_volume_2, 7)))