diff --git a/vnpy/app/spread_trading/base.py b/vnpy/app/spread_trading/base.py index 5eff4b48..e6777d30 100644 --- a/vnpy/app/spread_trading/base.py +++ b/vnpy/app/spread_trading/base.py @@ -172,24 +172,34 @@ class SpreadData: def calculate_pos(self): """""" - self.net_pos = 0 + long_pos = 0 + short_pos = 0 for n, leg in enumerate(self.legs.values()): + leg_long_pos = 0 + leg_short_pos = 0 + trading_multiplier = self.trading_multipliers[leg.vt_symbol] adjusted_net_pos = leg.net_pos / trading_multiplier if adjusted_net_pos > 0: adjusted_net_pos = floor(adjusted_net_pos) + leg_long_pos = adjusted_net_pos else: adjusted_net_pos = ceil(adjusted_net_pos) + leg_short_pos = abs(adjusted_net_pos) if not n: - self.net_pos = adjusted_net_pos + long_pos = leg_long_pos + short_pos = leg_short_pos else: - if adjusted_net_pos > 0: - self.net_pos = min(self.net_pos, adjusted_net_pos) - else: - self.net_pos = max(self.net_pos, adjusted_net_pos) + long_pos = min(long_pos, leg_long_pos) + short_pos = min(short_pos, leg_short_pos) + + if long_pos > 0: + self.net_pos = long_pos + else: + self.net_pos = short_pos def clear_price(self): """"""