[bug fix] 修复数字货币精度,持仓

This commit is contained in:
msincenselee 2020-03-28 14:18:35 +08:00
parent db6e107b95
commit 7f7f1e97ea
8 changed files with 112 additions and 168 deletions

View File

@ -41,7 +41,8 @@ from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
ContractData
ContractData,
PositionData
)
from vnpy.trader.constant import (
Exchange,
@ -127,10 +128,6 @@ class BackTestingEngine(object):
self.order_strategy_dict = {} # orderid 与 strategy的映射
# 持仓缓存字典
# key为vt_symbolvalue为PositionBuffer对象
self.pos_holding_dict = {}
self.trade_count = 0 # 成交编号
self.trade_dict = OrderedDict() # 用于统计成交收益时,还没处理得交易
self.trades = OrderedDict() # 记录所有得成交记录
@ -139,7 +136,7 @@ class BackTestingEngine(object):
self.long_position_list = [] # 多单持仓
self.short_position_list = [] # 空单持仓
self.holdings = {} # 多空持仓
self.positions = {} # 账号持仓对象为PositionData
# 当前最新数据,用于模拟成交用
self.gateway_name = u'BackTest'
@ -388,13 +385,13 @@ class BackTestingEngine(object):
def get_exchange(self, symbol: str):
return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓"""
if not gateway_name:
gateway_name = self.gateway_name
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:
k = f'{gateway_name}.{vt_symbol}.{direction.value}'
pos = self.positions.get(k, None)
if not pos:
contract = self.get_contract(vt_symbol)
if not contract:
self.write_log(f'{vt_symbol}合约信息不存在,构造一个')
@ -413,9 +410,14 @@ class BackTestingEngine(object):
size=self.get_size(vt_symbol),
pricetick=self.get_price_tick(vt_symbol),
margin_rate=self.get_margin_rate(vt_symbol))
holding = PositionHolding(contract)
self.holdings[k] = holding
return holding
pos = PositionData(
gateway_name=gateway_name,
symbol=contract.symbol,
exchange=contract.exchange,
direction=direction
)
self.positions[k] = pos
return pos
def set_name(self, test_name):
"""
@ -1081,9 +1083,13 @@ class BackTestingEngine(object):
self.append_trade(trade)
# 更新持仓缓存数据
k = '.'.join([self.gateway_name, trade.vt_symbol])
holding = self.get_position_holding(trade.vt_symbol, self.gateway_name)
holding.update_trade(trade)
pos = self.get_position(vt_symbol=trade.vt_symbol,direction=Direction.NET)
pre_volume = pos.volume
if trade.direction == Direction.LONG:
pos.volume = round(pos.volume + trade.volume, 7)
else:
pos.volume = round(pos.volume - trade.volume, 7)
self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}')
strategy.on_trade(trade)
@ -1160,22 +1166,27 @@ class BackTestingEngine(object):
trade.strategy_name = strategy.strategy_name
# 更新持仓缓存数据
k = '.'.join([self.gateway_name, trade.vt_symbol])
holding = self.get_position_holding(trade.vt_symbol, self.gateway_name)
holding.update_trade(trade)
strategy.on_trade(trade)
pos = self.get_position(vt_symbol=trade.vt_symbol, direction=Direction.NET)
pre_volume = pos.volume
if trade.direction == Direction.LONG:
pos.volume = round(pos.volume + trade.volume, 7)
else:
pos.volume = round(pos.volume - trade.volume, 7)
self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}')
self.trade_dict[trade.vt_tradeid] = trade
self.trades[trade.vt_tradeid] = copy.copy(trade)
self.write_log(u'vt_trade_id:{0}'.format(trade.vt_tradeid))
self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(trade.strategy_name,
self.write_log(u'{} : crossLimitOrder: TradeId:{}'.format(trade.strategy_name,
trade.tradeid,
holding.to_str()))
))
# 写入交易记录
self.append_trade(trade)
strategy.on_trade(trade)
# 更新资金曲线
fund_kline = self.get_fund_kline(trade.strategy_name)
if fund_kline:
@ -1193,20 +1204,6 @@ class BackTestingEngine(object):
# 实时计算模式
self.realtime_calculate()
def update_pos_buffer(self):
"""更新持仓信息,把今仓=>昨仓"""
for k, v in self.pos_holding_dict.items():
if v.long_td > 0:
self.write_log(u'调整多单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.long_td, v.long_yd, v.long_pos))
v.long_td = 0
v.longYd = v.long_pos
if v.short_td > 0:
self.write_log(u'调整空单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.short_td, v.short_yd, v.short_pos))
v.short_td = 0
v.short_yd = v.short_pos
def get_data_path(self):
"""
获取数据保存目录
@ -1842,19 +1839,6 @@ class BackTestingEngine(object):
else:
self.write_log(msg)
# 今仓 =》 昨仓
for holding in self.holdings.values():
if holding.long_td > 0:
self.write_log(
f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}')
holding.long_td = 0
holding.long_yd = holding.long_pos
if holding.short_td > 0:
self.write_log(
f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}')
holding.short_td = 0
holding.short_yd = holding.short_pos
# ---------------------------------------------------------------------
def export_trade_result(self):
"""

View File

@ -138,7 +138,7 @@ class CtaEngine(BaseEngine):
self.vt_tradeids = set() # for filtering duplicate trade
self.holdings = {}
self.positions = {}
self.last_minute = None
@ -243,9 +243,6 @@ class CtaEngine(BaseEngine):
""""""
order = event.data
holding = self.get_position_holding(order.vt_symbol, order.gateway_name)
holding.update_order(order)
strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
if not strategy:
return
@ -282,9 +279,6 @@ class CtaEngine(BaseEngine):
return
self.vt_tradeids.add(trade.vt_tradeid)
holding = self.get_position_holding(trade.vt_symbol, trade.gateway_name)
holding.update_trade(trade)
strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
if not strategy:
return
@ -341,8 +335,7 @@ class CtaEngine(BaseEngine):
""""""
position = event.data
holding = self.get_position_holding(position.vt_symbol, position.gateway_name)
holding.update_position(position)
self.positions.update({position.vt_positionid: position})
def check_unsubscribed_symbols(self):
"""检查未订阅合约"""
@ -786,7 +779,7 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
return 0.1
return 0.001
return contract.pricetick
@ -796,7 +789,7 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
return 1
return 0.01
return contract.min_volume
@ -832,22 +825,14 @@ class CtaEngine(BaseEngine):
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓,需要指定方向"""
contract = self.main_engine.get_contract(vt_symbol)
if contract:
if contract.gateway_name and not gateway_name:
gateway_name = contract.gateway_name
vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}"
return self.main_engine.get_position(vt_position_id)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:
symbol, exchange = extract_vt_symbol(vt_symbol)
contract = self.main_engine.get_contract(vt_symbol)
holding = PositionHolding(contract)
self.holdings[k] = holding
return holding
def get_engine_type(self):
""""""
return self.engine_type
@ -1498,19 +1483,15 @@ class CtaEngine(BaseEngine):
compare_pos = dict() # vt_symbol: {'账号多单': xx, '账号空单':xxx, '策略空单':[], '策略多单':[]}
for holding_key in list(self.holdings.keys()):
for position in list(self.positions.values()):
# gateway_name.symbol.exchange => symbol.exchange
vt_symbol = '.'.join(holding_key.split('.')[-2:])
vt_symbol = position.vt_symbol
vt_symbols.add(vt_symbol)
holding = self.holdings.get(holding_key, None)
if holding is None:
continue
compare_pos[vt_symbol] = OrderedDict(
{
"账号空单": holding.short_pos,
'账号多单': holding.long_pos,
"账号空单": abs(position.volume) if position.volume < 0 else 0,
'账号多单': position.volume if position.volume > 0 else 0,
'策略空单': 0,
'策略多单': 0,
'空单策略': [],

View File

@ -279,8 +279,7 @@ class PortfolioTestingEngine(BackTestingEngine):
# 第二个交易日,撤单
self.cancel_orders()
# 更新持仓缓存
self.update_pos_buffer()
gc_collect_days += 1
if gc_collect_days >= 10:

View File

@ -997,7 +997,9 @@ class CtaFutureTemplate(CtaTemplate):
"""
self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json()))
self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol)
self.account_pos = self.cta_engine.get_position(
vt_symbol=self.vt_symbol,
direction=Direction.NET)
if self.account_pos is None:
self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol))
@ -1050,7 +1052,9 @@ class CtaFutureTemplate(CtaTemplate):
"""
self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json()))
self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol)
self.account_pos = self.cta_engine.get_position(
vt_symbol=self.vt_symbol,
direction=Direction.NET)
if self.account_pos is None:
self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol))
return False
@ -1064,17 +1068,21 @@ class CtaFutureTemplate(CtaTemplate):
# 发出cover委托
if grid.traded_volume > 0:
grid.volume -= grid.traded_volume
grid.volume = round(grid.volume, 7)
grid.traded_volume = 0
grid.volume = round(grid.volume, 7)
if 0 < abs(self.account_pos.short_pos) < grid.volume:
self.write_error(u'当前{}的空单持仓:{},不满足平仓目标:{}, 强制降低'
if self.account_pos.volume >= 0:
self.write_error(u'当前{}的净持仓:{},不能平空单'
.format(self.vt_symbol,
self.account_pos.short_pos,
self.account_pos.volume))
return False
if abs(self.account_pos.volume) < grid.volume:
self.write_error(u'当前{}的净持仓:{},不满足平仓目标:{}, 强制降低'
.format(self.vt_symbol,
self.account_pos.volume,
grid.volume))
grid.volume = abs(self.account_pos.short_pos)
grid.volume = abs(self.account_pos.volume)
vt_orderids = self.cover(
price=cover_price,

View File

@ -77,11 +77,12 @@ class CtaLineBar(object):
self.lineM = None # 1分钟K线
lineMSetting = {}
lineMSetting['name'] = u'M1'
lineMSetting['interval'] = Interval.MINUTE
lineMSetting['bar_interval'] = 60 # 1分钟对应60秒
lineMSetting['inputEma1Len'] = 7 # EMA线1的周期
lineMSetting['inputEma2Len'] = 21 # EMA线2的周期
lineMSetting['inputBollLen'] = 20 # 布林特线周期
lineMSetting['inputBollStdRate'] = 2 # 布林特线标准差
lineMSetting['para_ema1_len'] = 7 # EMA线1的周期
lineMSetting['para_ema2_len'] = 21 # EMA线2的周期
lineMSetting['para_boll_len'] = 20 # 布林特线周期
lineMSetting['para_boll_std_rate'] = 2 # 布林特线标准差
lineMSetting['price_tick'] = self.price_tick # 最小条
lineMSetting['underlying_symbol'] = self.underlying_symbol #商品短号
self.lineM = CtaLineBar(self, self.onBar, lineMSetting)
@ -207,7 +208,7 @@ class CtaLineBar(object):
# 修正精度
if self.price_tick < 1:
exponent = decimal.Decimal(str(self.price_tick))
self.round_n = max(abs(exponent.as_tuple().exponent), 4)
self.round_n = max(abs(exponent.as_tuple().exponent) + 2, 4)
# 导入卡尔曼过滤器
if self.para_active_kf:

View File

@ -494,10 +494,13 @@ class BinancefRestApi(RestClient):
for asset in data["assets"]:
account = AccountData(
accountid=asset["asset"],
balance=float(asset["walletBalance"]),
balance=float(asset["walletBalance"]) + float(asset["maintMargin"]),
frozen=float(asset["maintMargin"]),
holding_profit=float(asset['unrealizedProfit']),
gateway_name=self.gateway_name
)
# 修正vnpy AccountData
account.balance += account.holding_profit
if account.balance:
self.gateway.on_account(account)
@ -516,69 +519,18 @@ class BinancefRestApi(RestClient):
""""""
for d in data:
volume = float(d["positionAmt"])
position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.NET,
volume=volume,
price=float(d["entryPrice"]),
pnl=float(d["unRealizedProfit"]),
gateway_name=self.gateway_name,
)
self.gateway.on_position(position)
if volume > 0:
long_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.LONG,
volume=abs(volume),
price=float(d["entryPrice"]),
pnl=float(d["unRealizedProfit"]),
gateway_name=self.gateway_name,
)
short_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
elif volume < 0:
long_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.LONG,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
short_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=abs(volume),
price=float(d["entryPrice"]),
pnl=float(d["unRealizedProfit"]),
gateway_name=self.gateway_name,
)
else:
long_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.LONG,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
short_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
self.gateway.on_position(long_position)
self.gateway.on_position(short_position)
# self.gateway.write_log("持仓信息查询成功")
# self.gateway.write_log("持仓信息查询成功")
def on_query_order(self, data: dict, request: Request) -> None:
""""""
@ -831,15 +783,34 @@ class BinancefTradeWebsocketApi(WebsocketClient):
def on_account(self, packet: dict) -> None:
""""""
holding_pnl = 0
for pos_data in packet["a"]["P"]:
print(pos_data)
volume = float(pos_data["pa"])
position = PositionData(
symbol=pos_data["s"],
exchange=Exchange.BINANCE,
direction=Direction.NET,
volume=abs(volume),
price=float(pos_data["ep"]),
pnl=float(pos_data["cr"]),
gateway_name=self.gateway_name,
)
holding_pnl += float(pos_data['up'])
self.gateway.on_position(position)
for acc_data in packet["a"]["B"]:
account = AccountData(
accountid=acc_data["a"],
balance=float(acc_data["wb"]),
balance=round(float(acc_data["wb"]), 7),
frozen=float(acc_data["wb"]) - float(acc_data["cw"]),
holding_profit=round(holding_pnl, 7),
gateway_name=self.gateway_name
)
if account.balance:
account.balance += account.holding_profit
account.available = float(acc_data["cw"])
self.gateway.on_account(account)
for pos_data in packet["a"]["P"]:

View File

@ -130,11 +130,11 @@ class PositionHolding:
if position.direction == Direction.LONG:
self.long_pos = position.volume
self.long_yd = position.yd_volume
self.long_td = self.long_pos - self.long_yd
self.long_td = round(self.long_pos - self.long_yd, 7)
else:
self.short_pos = position.volume
self.short_yd = position.yd_volume
self.short_td = self.short_pos - self.short_yd
self.short_td = round(self.short_pos - self.short_yd, 7)
def update_order(self, order: OrderData) -> None:
""""""
@ -211,7 +211,7 @@ class PositionHolding:
if order.offset == Offset.OPEN:
continue
frozen = order.volume - order.traded
frozen = round(order.volume - order.traded, 7)
if order.direction == Direction.LONG:
if order.offset == Offset.CLOSETODAY:
@ -238,8 +238,8 @@ class PositionHolding:
- self.long_td)
self.long_td_frozen = self.long_td
self.long_pos_frozen = self.long_td_frozen + self.long_yd_frozen
self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen
self.long_pos_frozen = round(self.long_td_frozen + self.long_yd_frozen, 7)
self.short_pos_frozen = round(self.short_td_frozen + self.short_yd_frozen, 7)
def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]:
"""上期所,委托单拆分"""

View File

@ -750,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget):
self.return_label.setText(f"{r:.2f}%")
if tick.bid_price_2:
self.bp2_label.setText(str(round(tick.bid_price_2), 7))
self.bp2_label.setText(str(round(tick.bid_price_2, 7)))
self.bv2_label.setText(str(round(tick.bid_volume_2, 7)))
self.ap2_label.setText(str(round(tick.ask_price_2, 7)))
self.av2_label.setText(str(round(tick.ask_volume_2, 7)))