[增强功能] ctp/sopt增加期权得动态权益,修正资金净值,界面增加当前价;cta_strategy_pro增加获取所有合约接口

This commit is contained in:
msincenselee 2020-08-28 11:37:45 +08:00
parent 7baa486d6b
commit c890b931b7
10 changed files with 192 additions and 42 deletions

View File

@ -2273,6 +2273,8 @@ class BackTestingEngine(object):
self.trade_dict.clear()
self.trades.clear()
self.trade_pnl_list = []
self.last_bar.clear()
self.last_dt = None
def append_trade(self, trade: TradeData):
"""

View File

@ -868,6 +868,9 @@ class CtaEngine(BaseEngine):
def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol)
def get_all_contracts(self):
return self.main_engine.get_all_contracts()
def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位

View File

@ -188,6 +188,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.cur_capital = self.init_capital # 更新设置期初资金
if not self.data_end_date:
self.data_end_date = datetime.today()
self.test_end_date = datetime.now().strftime('%Y%m%d')
# 保存回测脚本到数据库
self.save_setting_to_mongo()
@ -275,6 +276,7 @@ class PortfolioTestingEngine(BackTestingEngine):
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
bar.volume = int(bar_data['volume'])
bar.open_interest = int(bar_data.get('open_interest', 0))
bar.date = bar_datetime.strftime('%Y-%m-%d')
bar.time = bar_datetime.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))

View File

@ -1154,7 +1154,12 @@ class CtaProFutureTemplate(CtaProTemplate):
self.put_event()
def on_trade(self, trade: TradeData):
"""交易更新"""
"""
交易更新
支持股指期货的对锁单或者解锁
:param trade:
:return:
"""
self.write_log(u'{},交易更新事件:{},当前持仓:{} '
.format(self.cur_datetime,
trade.__dict__,
@ -1168,7 +1173,22 @@ class CtaProFutureTemplate(CtaProTemplate):
dist_record['volume'] = trade.volume
dist_record['price'] = trade.price
dist_record['symbol'] = trade.vt_symbol
if trade.exchange == Exchange.CFFEX:
if trade.direction == Direction.LONG:
if abs(self.position.short_pos) >= trade.volume:
self.position.short_pos += trade.volume
else:
self.position.long_pos += trade.volume
else:
if self.position.long_pos >= trade.volume:
self.position.long_pos -= trade.volume
else:
self.position.short_pos -= trade.volume
self.position.pos = self.position.long_pos + self.position.short_pos
dist_record['long_pos'] = self.position.long_pos
dist_record['short_pos'] = self.position.short_pos
else:
if trade.direction == Direction.LONG and trade.offset == Offset.OPEN:
dist_record['operation'] = 'buy'
self.position.open_pos(trade.direction, volume=trade.volume)

View File

@ -71,12 +71,21 @@ class CtaSpreadTemplate(CtaTemplate):
self.act_price_tick = None # 主动合约价格跳动
self.pas_price_tick = None # 被动合约价格跳动
self.act_symbol_size = None
self.pas_symbol_size = None
self.act_margin_rate = None
self.pas_margin_rate = None
self.act_pos = None # 主动合约得holding pos
self.pas_pos = None # 被动合约得holding pos
self.last_minute = None # 最后的分钟,用于on_tick内每分钟处理的逻辑
# 资金相关
self.max_invest_rate = 0.1 # 最大仓位(0~1)
self.max_invest_margin = 0 # 资金上限 0不限制
self.max_invest_pos = 0 # 单向头寸数量上限 0不限制
def update_setting(self, setting: dict):
"""更新配置参数"""
super().update_setting(setting)
@ -85,6 +94,10 @@ class CtaSpreadTemplate(CtaTemplate):
self.pas_symbol, self.pas_exchange = extract_vt_symbol(self.pas_vt_symbol)
self.act_price_tick = self.cta_engine.get_price_tick(self.act_vt_symbol)
self.pas_price_tick = self.cta_engine.get_price_tick(self.pas_vt_symbol)
self.act_symbol_size = self.cta_engine.get_size(self.act_vt_symbol)
self.pas_symbol_size = self.cta_engine.get_size(self.pas_vt_symbol)
self.act_margin_rate = self.cta_engine.get_margin_rate(self.act_vt_symbol)
self.pas_margin_rate = self.cta_engine.get_margin_rate(self.pas_vt_symbol)
# 实盘采用FAK
if not self.backtesting and self.activate_fak:
@ -467,6 +480,11 @@ class CtaSpreadTemplate(CtaTemplate):
except Exception as ex:
self.write_error(u'save_tns 异常:{} {}'.format(str(ex), traceback.format_exc()))
def save_data(self):
"""保存过程数据"""
if not self.backtesting:
return
def send_wechat(self, msg: str):
"""实盘时才发送微信"""
if self.backtesting:
@ -529,7 +547,7 @@ class CtaSpreadTemplate(CtaTemplate):
if trade.offset == Offset.OPEN:
# 更新开仓均价/数量
if trade.vt_symbol == self.act_vt_symbol:
opened_price = grid.snapshot.get('act_open_price', grid.volume * self.act_vol_ratio)
opened_price = grid.snapshot.get('act_open_price', 0)
opened_volume = grid.snapshot.get('act_open_volume', grid.volume * self.act_vol_ratio)
act_open_volume = opened_volume + trade.volume
act_open_price = (opened_price * opened_volume + trade.price * trade.volume) / act_open_volume
@ -632,6 +650,9 @@ class CtaSpreadTemplate(CtaTemplate):
# 在策略得活动订单中,移除
self.active_orders.pop(order.vt_orderid, None)
self.gt.save()
if len(self.active_orders) < 1:
self.entrust = 0
return
def on_order_open_canceled(self, order: OrderData):
"""
@ -1202,12 +1223,16 @@ class CtaSpreadTemplate(CtaTemplate):
self.write_error(f'spd_short{self.pas_vt_symbol}开多仓{grid.volume * self.pas_vol_ratio}手失败,'
f'委托价:{self.cur_pas_tick.ask_price_1}')
return []
grid.snapshot.update({"act_vt_symbol": self.act_vt_symbol, "act_open_volume": grid.volume * self.act_vol_ratio,
"pas_vt_symbol": self.pas_vt_symbol, "pas_open_volume": grid.volume * self.pas_vol_ratio})
# WJ: update_grid_trade() 中会根据实际交易的数目更新 act_open_volume & pas_open_volume
# 所以这里必须设置为初始值0否则grid中的 open_volume会是实际持仓的2倍导致spd_sell & spd_cover时失败
# grid.snapshot.update({"act_vt_symbol": self.act_vt_symbol, "act_open_volume": grid.volume * self.act_vol_ratio,
# "pas_vt_symbol": self.pas_vt_symbol, "pas_open_volume": grid.volume * self.pas_vol_ratio})
grid.snapshot.update({"act_vt_symbol": self.act_vt_symbol, "act_open_volume": 0,
"pas_vt_symbol": self.pas_vt_symbol, "pas_open_volume": 0})
grid.order_status = True
grid.order_datetime = self.cur_datetime
vt_orderids = act_vt_orderids.extend(pas_vt_orderids)
vt_orderids = act_vt_orderids + pas_vt_orderids # 不能用act_vt_orderids.extend(pas_vt_orderids),它的返回值为 None会导致没有vt_orderids
self.write_log(u'spd short vt_order_ids{0}'.format(vt_orderids))
return vt_orderids
@ -1264,11 +1289,11 @@ class CtaSpreadTemplate(CtaTemplate):
self.write_error(f'spd_short{self.pas_vt_symbol}开空仓{grid.volume * self.pas_vol_ratio}手失败,'
f'委托价:{self.cur_pas_tick.bid_price_1}')
return []
grid.snapshot.update({"act_vt_symbol": self.act_vt_symbol, "act_open_volume": grid.volume * self.act_vol_ratio,
"pas_vt_symbol": self.pas_vt_symbol, "pas_open_volume": grid.volume * self.pas_vol_ratio})
grid.snapshot.update({"act_vt_symbol": self.act_vt_symbol, "act_open_volume": 0,
"pas_vt_symbol": self.pas_vt_symbol, "pas_open_volume": 0})
grid.order_status = True
grid.order_datetime = self.cur_datetime
vt_orderids = act_vt_orderids.extend(pas_vt_orderids)
vt_orderids = act_vt_orderids + pas_vt_orderids
self.write_log(u'spd buy vt_ordderids{}'.format(vt_orderids))
return vt_orderids
@ -1342,7 +1367,7 @@ class CtaSpreadTemplate(CtaTemplate):
grid.order_status = True
grid.order_datetime = self.cur_datetime
vt_orderids = act_vt_orderids.extend(pas_vt_orderids)
vt_orderids = act_vt_orderids + pas_vt_orderids
self.write_log(f'spd sell vt_orderids{vt_orderids}')
return vt_orderids
@ -1415,6 +1440,6 @@ class CtaSpreadTemplate(CtaTemplate):
grid.order_status = True
grid.order_datetime = self.cur_datetime
vt_orderids = act_vt_orderids.extend(pas_vt_orderids)
vt_orderids = act_vt_orderids + pas_vt_orderids
self.write_log(f'spd cover vt_orderids{vt_orderids}')
return vt_orderids

View File

@ -214,6 +214,8 @@ class BinancefRestApi(RestClient):
self.orders = {}
self.cache_position_symbols = {}
self.accountid = ""
def sign(self, request: Request) -> Request:
@ -561,6 +563,15 @@ class BinancefRestApi(RestClient):
pnl=float(d["unRealizedProfit"]),
gateway_name=self.gateway_name,
)
# 如果持仓数量为0且不在之前缓存过的合约信息中不做on_position
if position.volume == 0:
if position.symbol not in self.cache_position_symbols:
continue
else:
if position.symbol not in self.cache_position_symbols:
self.cache_position_symbols.update({position.symbol: position.volume})
self.gateway.on_position(position)
#if position.symbol == 'BTCUSDT':
# self.gateway.write_log(f'{position.__dict__}\n {d}')

View File

@ -154,6 +154,7 @@ OPTIONTYPE_CTP2VT = {
MAX_FLOAT = sys.float_info.max
symbol_exchange_map = {}
option_name_map = {}
symbol_name_map = {}
symbol_size_map = {}
index_contracts = {}
@ -764,6 +765,9 @@ class CtpTdApi(TdApi):
self.accountid = self.userid
self.long_option_cost = None # 多头期权动态市值
self.short_option_cost = None # 空头期权动态市值
def onFrontConnected(self):
""""""
self.gateway.write_log("交易服务器连接成功")
@ -903,6 +907,9 @@ class CtpTdApi(TdApi):
# Update new position volume
position.volume += data["Position"]
if data["PositionProfit"] == 0 and position.symbol in option_name_map:
position.pnl += data["PositionCost"] - data["OpenCost"]
else:
position.pnl += data["PositionProfit"]
# Calculate average position price
@ -916,8 +923,34 @@ class CtpTdApi(TdApi):
else:
position.frozen += data["LongFrozen"]
position.cur_price = self.gateway.prices.get(position.vt_symbol, None)
if position.cur_price is None:
position.cur_price = position.price
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
if last:
self.long_option_cost = None
self.short_option_cost = None
for position in self.positions.values():
if position.symbol in option_name_map:
# 重新累计多头期权动态权益
if position.direction == Direction.LONG:
if self.long_option_cost is None:
self.long_option_cost = position.cur_price * position.volume * symbol_size_map.get(
position.symbol, 0)
else:
self.long_option_cost += position.cur_price * position.volume * symbol_size_map.get(
position.symbol, 0)
# 重新累计空头期权动态权益
if position.direction == Direction.SHORT:
if self.short_option_cost is None:
self.short_option_cost = position.cur_price * position.volume * symbol_size_map.get(
position.symbol, 0)
else:
self.short_option_cost += position.cur_price * position.volume * symbol_size_map.get(
position.symbol, 0)
self.gateway.on_position(position)
self.positions.clear()
@ -929,10 +962,16 @@ class CtpTdApi(TdApi):
if len(self.accountid)== 0:
self.accountid = data['AccountID']
balance = float(data["Balance"])
if self.long_option_cost is not None:
balance += self.long_option_cost
if self.short_option_cost is not None:
balance -= self.short_option_cost
account = AccountData(
accountid=data["AccountID"],
pre_balance=round(float(data['PreBalance']), 7),
balance=round(float(data["Balance"]), 7),
balance=round(balance, 7),
frozen=round(data["FrozenMargin"] + data["FrozenCash"] + data["FrozenCommission"], 7),
gateway_name=self.gateway_name
)
@ -986,6 +1025,7 @@ class CtpTdApi(TdApi):
contract.option_strike = data["StrikePrice"]
contract.option_index = str(data["StrikePrice"])
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d")
option_name_map[contract.symbol] = contract.name
self.gateway.on_contract(contract)

View File

@ -1167,7 +1167,7 @@ class TqMdApi():
return
try:
from tqsdk import TqApi
self.api = TqApi(_stock=True)
self.api = TqApi(_stock=True,url="wss://u.shinnytech.com/t/nfmd/front/mobile")
except Exception as e:
self.gateway.write_log(f'天勤股票行情API接入异常:'.format(str(e)))
self.gateway.write_log(traceback.format_exc())

View File

@ -125,7 +125,7 @@ CHINA_TZ = pytz.timezone("Asia/Shanghai")
symbol_exchange_map = {}
symbol_name_map = {}
symbol_size_map = {}
option_name_map = {}
class SoptGateway(BaseGateway):
"""
@ -388,7 +388,7 @@ class SoptMdApi(MdApi):
return
timestamp = f"{data['TradingDay']} {data['UpdateTime']}.{int(data['UpdateMillisec']/100)}"
dt = datetime.strptime(timestamp, "%Y%m%d %H:%M:%S.%f")
dt = CHINA_TZ.localize(dt)
#dt = CHINA_TZ.localize(dt)
tick = TickData(
symbol=symbol,
@ -523,6 +523,9 @@ class SoptTdApi(TdApi):
self.positions = {}
self.sysid_orderid_map = {}
self.long_option_cost = None # 多头期权动态市值
self.short_option_cost = None # 空头期权动态市值
def onFrontConnected(self):
""""""
self.gateway.write_log("交易服务器连接成功")
@ -618,11 +621,14 @@ class SoptTdApi(TdApi):
if not data:
return
#self.gateway.write_log(print_dict(data))
# Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None)
if not position:
position = PositionData(
accountid=self.userid,
symbol=data["InstrumentID"],
exchange=symbol_exchange_map[data["InstrumentID"]],
direction=DIRECTION_SOPT2VT[data["PosiDirection"]],
@ -646,6 +652,9 @@ class SoptTdApi(TdApi):
# Update new position volume
position.volume += data["Position"]
if data["PositionProfit"] == 0:
position.pnl += data["PositionCost"] - data["OpenCost"]
else:
position.pnl += data["PositionProfit"]
# Calculate average position price
@ -659,27 +668,63 @@ class SoptTdApi(TdApi):
else:
position.frozen += data["LongFrozen"]
position.cur_price = self.gateway.prices.get(position.vt_symbol, None)
if position.cur_price is None:
position.cur_price = position.price
self.gateway.subscribe(SubscribeRequest(symbol=position.symbol, exchange=position.exchange))
if last:
self.long_option_cost = None
self.short_option_cost = None
for position in self.positions.values():
if position.symbol in option_name_map:
# 重新累计多头期权动态权益
if position.direction == Direction.LONG:
if self.long_option_cost is None:
self.long_option_cost = position.cur_price * position.volume * symbol_size_map.get(position.symbol, 0)
else:
self.long_option_cost += position.cur_price * position.volume * symbol_size_map.get(position.symbol, 0)
# 重新累计空头期权动态权益
if position.direction == Direction.SHORT:
if self.short_option_cost is None:
self.short_option_cost = position.cur_price * position.volume * symbol_size_map.get(position.symbol, 0)
else:
self.short_option_cost += position.cur_price * position.volume * symbol_size_map.get(position.symbol, 0)
self.gateway.on_position(position)
self.positions.clear()
def onRspQryTradingAccount(self, data: dict, error: dict, reqid: int, last: bool):
""""""
balance = float(data["Balance"])
# 资金差额权利金正数是卖call或卖put收入权利金; 负数是买call、买put付出权利金
cash_in = data.get('CashIn')
#balance -= cash_in
if self.long_option_cost is not None:
balance += self.long_option_cost
if self.short_option_cost is not None:
balance -= self.short_option_cost
account = AccountData(
accountid=data["AccountID"],
balance=data["Balance"],
balance=balance,
frozen=data["FrozenMargin"] + data["FrozenCash"] + data["FrozenCommission"],
gateway_name=self.gateway_name
)
account.available = data["Available"]
account.commission = round(float(data['Commission']), 7)
account.margin = round(float(data['CurrMargin']), 7)
account.close_profit = round(float(data['CloseProfit']), 7)
account.holding_profit = round(float(data['PositionProfit']), 7)
account.trading_day = str(data.get('TradingDay',datetime.now().strftime('%Y-%m-%d')))
#self.gateway.write_log(print_dict(data))
account.available = data["Available"]
account.commission = round(float(data['Commission']), 7) + round(float(data['SpecProductCommission']), 7)
account.margin = round(float(data['CurrMargin']), 7)
account.close_profit = round(float(data['CloseProfit']), 7) + round(float(data['SpecProductCloseProfit']), 7)
account.holding_profit = round(float(data['PositionProfit']), 7) + round(float(data['SpecProductPositionProfit']), 7)
account.trading_day = str(data.get('TradingDay', datetime.now().strftime('%Y-%m-%d')))
if '-' not in account.trading_day and len(account.trading_day) == 8:
account.trading_day = '-'.join(
[
@ -701,7 +746,7 @@ class SoptTdApi(TdApi):
contract = ContractData(
symbol=data["InstrumentID"],
exchange=EXCHANGE_SOPT2VT[data["ExchangeID"]],
name=data["InstrumentName"],
name=data["InstrumentName"].strip(),
product=product,
size=data["VolumeMultiple"],
pricetick=data["PriceTick"],
@ -724,6 +769,7 @@ class SoptTdApi(TdApi):
contract.option_index = get_option_index(
contract.option_strike, data["InstrumentCode"]
)
option_name_map[contract.symbol] = contract.name
self.gateway.on_contract(contract)

View File

@ -467,6 +467,7 @@ class PositionMonitor(BaseMonitor):
"volume": {"display": "数量", "cell": BaseCell, "update": True},
"yd_volume": {"display": "昨仓", "cell": BaseCell, "update": True},
"frozen": {"display": "冻结", "cell": BaseCell, "update": True},
"cur_price": {"display": "当前价", "cell": BaseCell, "update": True},
"price": {"display": "均价", "cell": BaseCell, "update": True},
"pnl": {"display": "盈亏", "cell": PnlCell, "update": True},
"gateway_name": {"display": "接口", "cell": BaseCell, "update": False},