[bug fix] 修复数字货币精度,持仓

This commit is contained in:
msincenselee 2020-03-28 14:18:35 +08:00
parent db6e107b95
commit 7f7f1e97ea
8 changed files with 112 additions and 168 deletions

View File

@ -41,7 +41,8 @@ from vnpy.trader.object import (
TickData, TickData,
OrderData, OrderData,
TradeData, TradeData,
ContractData ContractData,
PositionData
) )
from vnpy.trader.constant import ( from vnpy.trader.constant import (
Exchange, Exchange,
@ -127,10 +128,6 @@ class BackTestingEngine(object):
self.order_strategy_dict = {} # orderid 与 strategy的映射 self.order_strategy_dict = {} # orderid 与 strategy的映射
# 持仓缓存字典
# key为vt_symbolvalue为PositionBuffer对象
self.pos_holding_dict = {}
self.trade_count = 0 # 成交编号 self.trade_count = 0 # 成交编号
self.trade_dict = OrderedDict() # 用于统计成交收益时,还没处理得交易 self.trade_dict = OrderedDict() # 用于统计成交收益时,还没处理得交易
self.trades = OrderedDict() # 记录所有得成交记录 self.trades = OrderedDict() # 记录所有得成交记录
@ -139,7 +136,7 @@ class BackTestingEngine(object):
self.long_position_list = [] # 多单持仓 self.long_position_list = [] # 多单持仓
self.short_position_list = [] # 空单持仓 self.short_position_list = [] # 空单持仓
self.holdings = {} # 多空持仓 self.positions = {} # 账号持仓对象为PositionData
# 当前最新数据,用于模拟成交用 # 当前最新数据,用于模拟成交用
self.gateway_name = u'BackTest' self.gateway_name = u'BackTest'
@ -388,13 +385,13 @@ class BackTestingEngine(object):
def get_exchange(self, symbol: str): def get_exchange(self, symbol: str):
return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL) return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''): def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)""" """ 查询合约在账号的持仓"""
if not gateway_name: if not gateway_name:
gateway_name = self.gateway_name gateway_name = self.gateway_name
k = f'{gateway_name}.{vt_symbol}' k = f'{gateway_name}.{vt_symbol}.{direction.value}'
holding = self.holdings.get(k, None) pos = self.positions.get(k, None)
if not holding: if not pos:
contract = self.get_contract(vt_symbol) contract = self.get_contract(vt_symbol)
if not contract: if not contract:
self.write_log(f'{vt_symbol}合约信息不存在,构造一个') self.write_log(f'{vt_symbol}合约信息不存在,构造一个')
@ -413,9 +410,14 @@ class BackTestingEngine(object):
size=self.get_size(vt_symbol), size=self.get_size(vt_symbol),
pricetick=self.get_price_tick(vt_symbol), pricetick=self.get_price_tick(vt_symbol),
margin_rate=self.get_margin_rate(vt_symbol)) margin_rate=self.get_margin_rate(vt_symbol))
holding = PositionHolding(contract) pos = PositionData(
self.holdings[k] = holding gateway_name=gateway_name,
return holding symbol=contract.symbol,
exchange=contract.exchange,
direction=direction
)
self.positions[k] = pos
return pos
def set_name(self, test_name): def set_name(self, test_name):
""" """
@ -1081,9 +1083,13 @@ class BackTestingEngine(object):
self.append_trade(trade) self.append_trade(trade)
# 更新持仓缓存数据 # 更新持仓缓存数据
k = '.'.join([self.gateway_name, trade.vt_symbol]) pos = self.get_position(vt_symbol=trade.vt_symbol,direction=Direction.NET)
holding = self.get_position_holding(trade.vt_symbol, self.gateway_name) pre_volume = pos.volume
holding.update_trade(trade) if trade.direction == Direction.LONG:
pos.volume = round(pos.volume + trade.volume, 7)
else:
pos.volume = round(pos.volume - trade.volume, 7)
self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}')
strategy.on_trade(trade) strategy.on_trade(trade)
@ -1160,22 +1166,27 @@ class BackTestingEngine(object):
trade.strategy_name = strategy.strategy_name trade.strategy_name = strategy.strategy_name
# 更新持仓缓存数据 # 更新持仓缓存数据
k = '.'.join([self.gateway_name, trade.vt_symbol]) pos = self.get_position(vt_symbol=trade.vt_symbol, direction=Direction.NET)
holding = self.get_position_holding(trade.vt_symbol, self.gateway_name) pre_volume = pos.volume
holding.update_trade(trade) if trade.direction == Direction.LONG:
strategy.on_trade(trade) pos.volume = round(pos.volume + trade.volume, 7)
else:
pos.volume = round(pos.volume - trade.volume, 7)
self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}')
self.trade_dict[trade.vt_tradeid] = trade self.trade_dict[trade.vt_tradeid] = trade
self.trades[trade.vt_tradeid] = copy.copy(trade) self.trades[trade.vt_tradeid] = copy.copy(trade)
self.write_log(u'vt_trade_id:{0}'.format(trade.vt_tradeid)) self.write_log(u'vt_trade_id:{0}'.format(trade.vt_tradeid))
self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(trade.strategy_name, self.write_log(u'{} : crossLimitOrder: TradeId:{}'.format(trade.strategy_name,
trade.tradeid, trade.tradeid,
holding.to_str())) ))
# 写入交易记录 # 写入交易记录
self.append_trade(trade) self.append_trade(trade)
strategy.on_trade(trade)
# 更新资金曲线 # 更新资金曲线
fund_kline = self.get_fund_kline(trade.strategy_name) fund_kline = self.get_fund_kline(trade.strategy_name)
if fund_kline: if fund_kline:
@ -1193,20 +1204,6 @@ class BackTestingEngine(object):
# 实时计算模式 # 实时计算模式
self.realtime_calculate() self.realtime_calculate()
def update_pos_buffer(self):
"""更新持仓信息,把今仓=>昨仓"""
for k, v in self.pos_holding_dict.items():
if v.long_td > 0:
self.write_log(u'调整多单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.long_td, v.long_yd, v.long_pos))
v.long_td = 0
v.longYd = v.long_pos
if v.short_td > 0:
self.write_log(u'调整空单持仓:今仓{}=> 0 昨仓{} => 昨仓:{}'.format(v.short_td, v.short_yd, v.short_pos))
v.short_td = 0
v.short_yd = v.short_pos
def get_data_path(self): def get_data_path(self):
""" """
获取数据保存目录 获取数据保存目录
@ -1842,19 +1839,6 @@ class BackTestingEngine(object):
else: else:
self.write_log(msg) self.write_log(msg)
# 今仓 =》 昨仓
for holding in self.holdings.values():
if holding.long_td > 0:
self.write_log(
f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}')
holding.long_td = 0
holding.long_yd = holding.long_pos
if holding.short_td > 0:
self.write_log(
f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}')
holding.short_td = 0
holding.short_yd = holding.short_pos
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
def export_trade_result(self): def export_trade_result(self):
""" """

View File

@ -138,7 +138,7 @@ class CtaEngine(BaseEngine):
self.vt_tradeids = set() # for filtering duplicate trade self.vt_tradeids = set() # for filtering duplicate trade
self.holdings = {} self.positions = {}
self.last_minute = None self.last_minute = None
@ -243,9 +243,6 @@ class CtaEngine(BaseEngine):
"""""" """"""
order = event.data order = event.data
holding = self.get_position_holding(order.vt_symbol, order.gateway_name)
holding.update_order(order)
strategy = self.orderid_strategy_map.get(order.vt_orderid, None) strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
if not strategy: if not strategy:
return return
@ -282,9 +279,6 @@ class CtaEngine(BaseEngine):
return return
self.vt_tradeids.add(trade.vt_tradeid) self.vt_tradeids.add(trade.vt_tradeid)
holding = self.get_position_holding(trade.vt_symbol, trade.gateway_name)
holding.update_trade(trade)
strategy = self.orderid_strategy_map.get(trade.vt_orderid, None) strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
if not strategy: if not strategy:
return return
@ -341,8 +335,7 @@ class CtaEngine(BaseEngine):
"""""" """"""
position = event.data position = event.data
holding = self.get_position_holding(position.vt_symbol, position.gateway_name) self.positions.update({position.vt_positionid: position})
holding.update_position(position)
def check_unsubscribed_symbols(self): def check_unsubscribed_symbols(self):
"""检查未订阅合约""" """检查未订阅合约"""
@ -786,7 +779,7 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract is None: if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息') self.write_error(f'查询不到{vt_symbol}合约信息')
return 0.1 return 0.001
return contract.pricetick return contract.pricetick
@ -796,7 +789,7 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract is None: if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息') self.write_error(f'查询不到{vt_symbol}合约信息')
return 1 return 0.01
return contract.min_volume return contract.min_volume
@ -832,22 +825,14 @@ class CtaEngine(BaseEngine):
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓,需要指定方向""" """ 查询合约在账号的持仓,需要指定方向"""
contract = self.main_engine.get_contract(vt_symbol)
if contract:
if contract.gateway_name and not gateway_name:
gateway_name = contract.gateway_name
vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}" vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}"
return self.main_engine.get_position(vt_position_id) return self.main_engine.get_position(vt_position_id)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:
symbol, exchange = extract_vt_symbol(vt_symbol)
contract = self.main_engine.get_contract(vt_symbol)
holding = PositionHolding(contract)
self.holdings[k] = holding
return holding
def get_engine_type(self): def get_engine_type(self):
"""""" """"""
return self.engine_type return self.engine_type
@ -1498,19 +1483,15 @@ class CtaEngine(BaseEngine):
compare_pos = dict() # vt_symbol: {'账号多单': xx, '账号空单':xxx, '策略空单':[], '策略多单':[]} compare_pos = dict() # vt_symbol: {'账号多单': xx, '账号空单':xxx, '策略空单':[], '策略多单':[]}
for holding_key in list(self.holdings.keys()): for position in list(self.positions.values()):
# gateway_name.symbol.exchange => symbol.exchange # gateway_name.symbol.exchange => symbol.exchange
vt_symbol = '.'.join(holding_key.split('.')[-2:]) vt_symbol = position.vt_symbol
vt_symbols.add(vt_symbol) vt_symbols.add(vt_symbol)
holding = self.holdings.get(holding_key, None)
if holding is None:
continue
compare_pos[vt_symbol] = OrderedDict( compare_pos[vt_symbol] = OrderedDict(
{ {
"账号空单": holding.short_pos, "账号空单": abs(position.volume) if position.volume < 0 else 0,
'账号多单': holding.long_pos, '账号多单': position.volume if position.volume > 0 else 0,
'策略空单': 0, '策略空单': 0,
'策略多单': 0, '策略多单': 0,
'空单策略': [], '空单策略': [],

View File

@ -279,8 +279,7 @@ class PortfolioTestingEngine(BackTestingEngine):
# 第二个交易日,撤单 # 第二个交易日,撤单
self.cancel_orders() self.cancel_orders()
# 更新持仓缓存
self.update_pos_buffer()
gc_collect_days += 1 gc_collect_days += 1
if gc_collect_days >= 10: if gc_collect_days >= 10:

View File

@ -997,7 +997,9 @@ class CtaFutureTemplate(CtaTemplate):
""" """
self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json())) self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json()))
self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol) self.account_pos = self.cta_engine.get_position(
vt_symbol=self.vt_symbol,
direction=Direction.NET)
if self.account_pos is None: if self.account_pos is None:
self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol)) self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol))
@ -1050,7 +1052,9 @@ class CtaFutureTemplate(CtaTemplate):
""" """
self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json())) self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json()))
self.account_pos = self.cta_engine.get_position_holding(self.vt_symbol) self.account_pos = self.cta_engine.get_position(
vt_symbol=self.vt_symbol,
direction=Direction.NET)
if self.account_pos is None: if self.account_pos is None:
self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol)) self.write_error(u'无法获取{}得持仓信息'.format(self.vt_symbol))
return False return False
@ -1064,17 +1068,21 @@ class CtaFutureTemplate(CtaTemplate):
# 发出cover委托 # 发出cover委托
if grid.traded_volume > 0: if grid.traded_volume > 0:
grid.volume -= grid.traded_volume grid.volume -= grid.traded_volume
grid.volume = round(grid.volume, 7)
grid.traded_volume = 0 grid.traded_volume = 0
grid.volume = round(grid.volume, 7) if self.account_pos.volume >= 0:
self.write_error(u'当前{}的净持仓:{},不能平空单'
if 0 < abs(self.account_pos.short_pos) < grid.volume:
self.write_error(u'当前{}的空单持仓:{},不满足平仓目标:{}, 强制降低'
.format(self.vt_symbol, .format(self.vt_symbol,
self.account_pos.short_pos, self.account_pos.volume))
return False
if abs(self.account_pos.volume) < grid.volume:
self.write_error(u'当前{}的净持仓:{},不满足平仓目标:{}, 强制降低'
.format(self.vt_symbol,
self.account_pos.volume,
grid.volume)) grid.volume))
grid.volume = abs(self.account_pos.short_pos) grid.volume = abs(self.account_pos.volume)
vt_orderids = self.cover( vt_orderids = self.cover(
price=cover_price, price=cover_price,

View File

@ -77,11 +77,12 @@ class CtaLineBar(object):
self.lineM = None # 1分钟K线 self.lineM = None # 1分钟K线
lineMSetting = {} lineMSetting = {}
lineMSetting['name'] = u'M1' lineMSetting['name'] = u'M1'
lineMSetting['interval'] = Interval.MINUTE
lineMSetting['bar_interval'] = 60 # 1分钟对应60秒 lineMSetting['bar_interval'] = 60 # 1分钟对应60秒
lineMSetting['inputEma1Len'] = 7 # EMA线1的周期 lineMSetting['para_ema1_len'] = 7 # EMA线1的周期
lineMSetting['inputEma2Len'] = 21 # EMA线2的周期 lineMSetting['para_ema2_len'] = 21 # EMA线2的周期
lineMSetting['inputBollLen'] = 20 # 布林特线周期 lineMSetting['para_boll_len'] = 20 # 布林特线周期
lineMSetting['inputBollStdRate'] = 2 # 布林特线标准差 lineMSetting['para_boll_std_rate'] = 2 # 布林特线标准差
lineMSetting['price_tick'] = self.price_tick # 最小条 lineMSetting['price_tick'] = self.price_tick # 最小条
lineMSetting['underlying_symbol'] = self.underlying_symbol #商品短号 lineMSetting['underlying_symbol'] = self.underlying_symbol #商品短号
self.lineM = CtaLineBar(self, self.onBar, lineMSetting) self.lineM = CtaLineBar(self, self.onBar, lineMSetting)
@ -207,7 +208,7 @@ class CtaLineBar(object):
# 修正精度 # 修正精度
if self.price_tick < 1: if self.price_tick < 1:
exponent = decimal.Decimal(str(self.price_tick)) exponent = decimal.Decimal(str(self.price_tick))
self.round_n = max(abs(exponent.as_tuple().exponent), 4) self.round_n = max(abs(exponent.as_tuple().exponent) + 2, 4)
# 导入卡尔曼过滤器 # 导入卡尔曼过滤器
if self.para_active_kf: if self.para_active_kf:

View File

@ -494,10 +494,13 @@ class BinancefRestApi(RestClient):
for asset in data["assets"]: for asset in data["assets"]:
account = AccountData( account = AccountData(
accountid=asset["asset"], accountid=asset["asset"],
balance=float(asset["walletBalance"]), balance=float(asset["walletBalance"]) + float(asset["maintMargin"]),
frozen=float(asset["maintMargin"]), frozen=float(asset["maintMargin"]),
holding_profit=float(asset['unrealizedProfit']),
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# 修正vnpy AccountData
account.balance += account.holding_profit
if account.balance: if account.balance:
self.gateway.on_account(account) self.gateway.on_account(account)
@ -516,69 +519,18 @@ class BinancefRestApi(RestClient):
"""""" """"""
for d in data: for d in data:
volume = float(d["positionAmt"]) volume = float(d["positionAmt"])
position = PositionData(
if volume > 0: symbol=d["symbol"],
long_position = PositionData( exchange=Exchange.BINANCE,
symbol=d["symbol"], direction=Direction.NET,
exchange=Exchange.BINANCE, volume=volume,
direction=Direction.LONG, price=float(d["entryPrice"]),
volume=abs(volume), pnl=float(d["unRealizedProfit"]),
price=float(d["entryPrice"]), gateway_name=self.gateway_name,
pnl=float(d["unRealizedProfit"]), )
gateway_name=self.gateway_name, self.gateway.on_position(position)
)
short_position = PositionData( # self.gateway.write_log("持仓信息查询成功")
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
elif volume < 0:
long_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.LONG,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
short_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=abs(volume),
price=float(d["entryPrice"]),
pnl=float(d["unRealizedProfit"]),
gateway_name=self.gateway_name,
)
else:
long_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.LONG,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
short_position = PositionData(
symbol=d["symbol"],
exchange=Exchange.BINANCE,
direction=Direction.SHORT,
volume=0,
price=0,
pnl=0,
gateway_name=self.gateway_name,
)
self.gateway.on_position(long_position)
self.gateway.on_position(short_position)
# self.gateway.write_log("持仓信息查询成功")
def on_query_order(self, data: dict, request: Request) -> None: def on_query_order(self, data: dict, request: Request) -> None:
"""""" """"""
@ -831,15 +783,34 @@ class BinancefTradeWebsocketApi(WebsocketClient):
def on_account(self, packet: dict) -> None: def on_account(self, packet: dict) -> None:
"""""" """"""
holding_pnl = 0
for pos_data in packet["a"]["P"]:
print(pos_data)
volume = float(pos_data["pa"])
position = PositionData(
symbol=pos_data["s"],
exchange=Exchange.BINANCE,
direction=Direction.NET,
volume=abs(volume),
price=float(pos_data["ep"]),
pnl=float(pos_data["cr"]),
gateway_name=self.gateway_name,
)
holding_pnl += float(pos_data['up'])
self.gateway.on_position(position)
for acc_data in packet["a"]["B"]: for acc_data in packet["a"]["B"]:
account = AccountData( account = AccountData(
accountid=acc_data["a"], accountid=acc_data["a"],
balance=float(acc_data["wb"]), balance=round(float(acc_data["wb"]), 7),
frozen=float(acc_data["wb"]) - float(acc_data["cw"]), frozen=float(acc_data["wb"]) - float(acc_data["cw"]),
holding_profit=round(holding_pnl, 7),
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
if account.balance: if account.balance:
account.balance += account.holding_profit
account.available = float(acc_data["cw"])
self.gateway.on_account(account) self.gateway.on_account(account)
for pos_data in packet["a"]["P"]: for pos_data in packet["a"]["P"]:

View File

@ -130,11 +130,11 @@ class PositionHolding:
if position.direction == Direction.LONG: if position.direction == Direction.LONG:
self.long_pos = position.volume self.long_pos = position.volume
self.long_yd = position.yd_volume self.long_yd = position.yd_volume
self.long_td = self.long_pos - self.long_yd self.long_td = round(self.long_pos - self.long_yd, 7)
else: else:
self.short_pos = position.volume self.short_pos = position.volume
self.short_yd = position.yd_volume self.short_yd = position.yd_volume
self.short_td = self.short_pos - self.short_yd self.short_td = round(self.short_pos - self.short_yd, 7)
def update_order(self, order: OrderData) -> None: def update_order(self, order: OrderData) -> None:
"""""" """"""
@ -211,7 +211,7 @@ class PositionHolding:
if order.offset == Offset.OPEN: if order.offset == Offset.OPEN:
continue continue
frozen = order.volume - order.traded frozen = round(order.volume - order.traded, 7)
if order.direction == Direction.LONG: if order.direction == Direction.LONG:
if order.offset == Offset.CLOSETODAY: if order.offset == Offset.CLOSETODAY:
@ -238,8 +238,8 @@ class PositionHolding:
- self.long_td) - self.long_td)
self.long_td_frozen = self.long_td self.long_td_frozen = self.long_td
self.long_pos_frozen = self.long_td_frozen + self.long_yd_frozen self.long_pos_frozen = round(self.long_td_frozen + self.long_yd_frozen, 7)
self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen self.short_pos_frozen = round(self.short_td_frozen + self.short_yd_frozen, 7)
def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]: def convert_order_request_shfe(self, req: OrderRequest) -> List[OrderRequest]:
"""上期所,委托单拆分""" """上期所,委托单拆分"""

View File

@ -750,7 +750,7 @@ class TradingWidget(QtWidgets.QWidget):
self.return_label.setText(f"{r:.2f}%") self.return_label.setText(f"{r:.2f}%")
if tick.bid_price_2: if tick.bid_price_2:
self.bp2_label.setText(str(round(tick.bid_price_2), 7)) self.bp2_label.setText(str(round(tick.bid_price_2, 7)))
self.bv2_label.setText(str(round(tick.bid_volume_2, 7))) self.bv2_label.setText(str(round(tick.bid_volume_2, 7)))
self.ap2_label.setText(str(round(tick.ask_price_2, 7))) self.ap2_label.setText(str(round(tick.ask_price_2, 7)))
self.av2_label.setText(str(round(tick.ask_volume_2, 7))) self.av2_label.setText(str(round(tick.ask_volume_2, 7)))