507 lines
14 KiB
Python
507 lines
14 KiB
Python
"""
|
|
"""
|
|
|
|
import logging
|
|
import smtplib
|
|
from abc import ABC
|
|
from datetime import datetime
|
|
from email.message import EmailMessage
|
|
from queue import Empty, Queue
|
|
from threading import Thread
|
|
from typing import Any
|
|
|
|
from vnpy.event import Event, EventEngine
|
|
from .app import BaseApp
|
|
from .event import (
|
|
EVENT_TICK,
|
|
EVENT_ORDER,
|
|
EVENT_TRADE,
|
|
EVENT_POSITION,
|
|
EVENT_ACCOUNT,
|
|
EVENT_CONTRACT,
|
|
EVENT_LOG
|
|
)
|
|
from .gateway import BaseGateway
|
|
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
|
|
from .setting import SETTINGS
|
|
from .utility import Singleton, get_folder_path
|
|
|
|
|
|
class MainEngine:
    """
    Acts as the core of VN Trader.

    Owns the event engine and keeps registries of gateways, function
    engines and apps (each keyed by its name), and routes requests to the
    gateway selected by name.
    """

    def __init__(self, event_engine: EventEngine = None):
        """
        Create the main engine.

        If event_engine is not given, a new EventEngine is created. In
        either case the event engine is started here, and the built-in
        log, OMS and email engines are added.
        """
        if event_engine:
            self.event_engine = event_engine
        else:
            self.event_engine = EventEngine()
        self.event_engine.start()

        self.gateways = {}
        self.engines = {}
        self.apps = {}

        self.init_engines()

    def add_engine(self, engine_class: Any):
        """
        Add function engine.

        Instantiates engine_class(main_engine, event_engine), registers it
        under its engine_name (replacing any engine with the same name) and
        returns the new engine.
        """
        engine = engine_class(self, self.event_engine)
        self.engines[engine.engine_name] = engine
        return engine

    def add_gateway(self, gateway_class: BaseGateway):
        """
        Add gateway.

        Instantiates gateway_class(event_engine), registers it under its
        gateway_name and returns the new gateway.
        """
        gateway = gateway_class(self.event_engine)
        self.gateways[gateway.gateway_name] = gateway
        return gateway

    def add_app(self, app_class: BaseApp):
        """
        Add app.

        Instantiates app_class(), registers it under its app_name, adds the
        app's engine_class as a function engine and returns the app.
        """
        app = app_class()
        self.apps[app.app_name] = app

        self.add_engine(app.engine_class)
        return app

    def init_engines(self):
        """
        Init all engines.

        Adds the built-in log, OMS and email function engines.
        """
        self.add_engine(LogEngine)
        self.add_engine(OmsEngine)
        self.add_engine(EmailEngine)

    def write_log(self, msg: str, source: str = ""):
        """
        Put log event with specific message.

        source is stored as the LogData gateway_name.
        """
        log = LogData(msg=msg, gateway_name=source)
        event = Event(EVENT_LOG, log)
        self.event_engine.put(event)

    def get_gateway(self, gateway_name: str):
        """
        Return gateway object by name.

        Returns None (and writes a log message) if no such gateway exists.
        """
        gateway = self.gateways.get(gateway_name, None)
        if not gateway:
            self.write_log(f"找不到底层接口:{gateway_name}")
        return gateway

    def get_engine(self, engine_name: str):
        """
        Return engine object by name.

        Returns None (and writes a log message) if no such engine exists.
        """
        engine = self.engines.get(engine_name, None)
        if not engine:
            self.write_log(f"找不到引擎:{engine_name}")
        return engine

    def get_default_setting(self, gateway_name: str):
        """
        Get default setting dict of a specific gateway.

        Returns None if the gateway is not found.
        """
        gateway = self.get_gateway(gateway_name)
        if gateway:
            return gateway.get_default_setting()
        return None

    def get_all_gateway_names(self):
        """
        Get all names of gateways added in main engine.
        """
        return list(self.gateways.keys())

    def get_all_apps(self):
        """
        Get all app objects.
        """
        return list(self.apps.values())

    def connect(self, setting: dict, gateway_name: str):
        """
        Start connection of a specific gateway.

        Does nothing beyond logging if the gateway is not found.
        """
        gateway = self.get_gateway(gateway_name)
        if gateway:
            gateway.connect(setting)

    def subscribe(self, req: SubscribeRequest, gateway_name: str):
        """
        Subscribe tick data update of a specific gateway.
        """
        gateway = self.get_gateway(gateway_name)
        if gateway:
            gateway.subscribe(req)

    def send_order(self, req: OrderRequest, gateway_name: str):
        """
        Send new order request to a specific gateway.

        Returns whatever the gateway's send_order returns, or "" if the
        gateway is not found.
        """
        gateway = self.get_gateway(gateway_name)
        if gateway:
            return gateway.send_order(req)
        else:
            return ""

    def cancel_order(self, req: CancelRequest, gateway_name: str):
        """
        Send cancel order request to a specific gateway.
        """
        gateway = self.get_gateway(gateway_name)
        if gateway:
            gateway.cancel_order(req)

    def close(self):
        """
        Make sure every gateway and app is closed properly before
        programme exit.
        """
        # Stop event engine first to prevent new timer event.
        self.event_engine.stop()

        for engine in self.engines.values():
            engine.close()

        for gateway in self.gateways.values():
            gateway.close()
|
|
|
|
|
|
class BaseEngine(ABC):
    """
    Base class for a function engine owned by the main engine.

    Holds references to the main engine and the event engine, plus the
    engine's name, which MainEngine uses as the key in its engine registry.
    """

    def __init__(
        self,
        main_engine: MainEngine,
        event_engine: EventEngine,
        engine_name: str,
    ):
        """Store the owning engines and this engine's name."""
        self.engine_name = engine_name
        self.event_engine = event_engine
        self.main_engine = main_engine

    def close(self):
        """Shut the engine down; the base implementation has nothing to release."""
        return None
|
|
|
|
|
|
class LogEngine(BaseEngine):
    """
    Processes log event and output with logging module.

    Behaviour is driven by SETTINGS. When "log.active" is false, the engine
    does nothing. Otherwise EVENT_LOG events are forwarded to the
    "VN Trader" logger, with optional console ("log.console") and daily
    file ("log.file") output.
    """

    # NOTE(review): ``__metaclass__`` is Python 2 syntax and has no effect in
    # Python 3, so LogEngine is not actually a singleton; every instance adds
    # its own handlers to the shared "VN Trader" logger. Switching to
    # ``class LogEngine(BaseEngine, metaclass=Singleton)`` would likely raise
    # a metaclass conflict with ABCMeta (from BaseEngine's ABC base) unless
    # Singleton derives from ABCMeta -- confirm against .utility before
    # changing.
    __metaclass__ = Singleton

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """
        Configure the logger from SETTINGS and subscribe to EVENT_LOG.

        Returns early, without creating a logger or registering any event
        handler, if SETTINGS["log.active"] is false.
        """
        super().__init__(main_engine, event_engine, "log")

        if not SETTINGS["log.active"]:
            return

        self.level = SETTINGS["log.level"]

        self.logger = logging.getLogger("VN Trader")
        self.logger.setLevel(self.level)

        self.formatter = logging.Formatter(
            "%(asctime)s %(levelname)s: %(message)s"
        )

        # Always attach a NullHandler so the logger has a handler even when
        # console and file output are both disabled.
        self.add_null_handler()

        if SETTINGS["log.console"]:
            self.add_console_handler()

        if SETTINGS["log.file"]:
            self.add_file_handler()

        self.register_event()

    def add_null_handler(self):
        """
        Add null handler for logger.
        """
        null_handler = logging.NullHandler()
        self.logger.addHandler(null_handler)

    def add_console_handler(self):
        """
        Add console output of log.

        Uses a StreamHandler (stderr by default) at the configured level.
        """
        console_handler = logging.StreamHandler()
        console_handler.setLevel(self.level)
        console_handler.setFormatter(self.formatter)
        self.logger.addHandler(console_handler)

    def add_file_handler(self):
        """
        Add file output of log.

        Writes to vt_YYYYMMDD.log inside the folder returned by
        get_folder_path("log").
        """
        today_date = datetime.now().strftime("%Y%m%d")
        filename = f"vt_{today_date}.log"
        log_path = get_folder_path("log")
        file_path = log_path.joinpath(filename)

        # Append rather than truncate ("w"), so restarting the program on
        # the same day does not erase log lines already written today.
        file_handler = logging.FileHandler(
            file_path, mode="a", encoding="utf8"
        )
        file_handler.setLevel(self.level)
        file_handler.setFormatter(self.formatter)
        self.logger.addHandler(file_handler)

    def register_event(self):
        """Route EVENT_LOG events to process_log_event."""
        self.event_engine.register(EVENT_LOG, self.process_log_event)

    def process_log_event(self, event: Event):
        """
        Output log event data with logging function.

        Assumes event.data is a LogData whose ``level`` is a logging level
        accepted by Logger.log -- confirm against .object.LogData.
        """
        log = event.data
        self.logger.log(log.level, log.msg)
|
|
|
|
|
|
class OmsEngine(BaseEngine):
    """
    Order management system (OMS) engine for VN Trader.

    Keeps the latest tick, order, trade, position, account and contract
    objects received through events, tracks which orders are still active,
    and installs query functions for all of them on the main engine.
    """

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create empty caches, expose query functions and subscribe to events."""
        super().__init__(main_engine, event_engine, "oms")

        # Latest data objects, each keyed by its vt_* identifier.
        self.ticks = {}
        self.orders = {}
        self.trades = {}
        self.positions = {}
        self.accounts = {}
        self.contracts = {}

        # Orders whose most recent update reported is_active() == True.
        self.active_orders = {}

        self.add_function()
        self.register_event()

    def add_function(self):
        """Attach this engine's query methods to the main engine by name."""
        query_names = (
            "get_tick",
            "get_order",
            "get_trade",
            "get_position",
            "get_account",
            "get_contract",
            "get_all_ticks",
            "get_all_orders",
            "get_all_trades",
            "get_all_positions",
            "get_all_accounts",
            "get_all_contracts",
            "get_all_active_orders",
        )
        for name in query_names:
            setattr(self.main_engine, name, getattr(self, name))

    def register_event(self):
        """Subscribe one cache-update handler per data event type."""
        subscriptions = (
            (EVENT_TICK, self.process_tick_event),
            (EVENT_ORDER, self.process_order_event),
            (EVENT_TRADE, self.process_trade_event),
            (EVENT_POSITION, self.process_position_event),
            (EVENT_ACCOUNT, self.process_account_event),
            (EVENT_CONTRACT, self.process_contract_event),
        )
        for event_type, handler in subscriptions:
            self.event_engine.register(event_type, handler)

    def process_tick_event(self, event: Event):
        """Cache the tick under its vt_symbol."""
        data = event.data
        self.ticks[data.vt_symbol] = data

    def process_order_event(self, event: Event):
        """Cache the order and keep the active-order index in sync."""
        data = event.data
        self.orders[data.vt_orderid] = data

        if data.is_active():
            self.active_orders[data.vt_orderid] = data
        else:
            # No longer active: forget it if it was being tracked.
            self.active_orders.pop(data.vt_orderid, None)

    def process_trade_event(self, event: Event):
        """Cache the trade under its vt_tradeid."""
        data = event.data
        self.trades[data.vt_tradeid] = data

    def process_position_event(self, event: Event):
        """Cache the position under its vt_positionid."""
        data = event.data
        self.positions[data.vt_positionid] = data

    def process_account_event(self, event: Event):
        """Cache the account under its vt_accountid."""
        data = event.data
        self.accounts[data.vt_accountid] = data

    def process_contract_event(self, event: Event):
        """Cache the contract under its vt_symbol."""
        data = event.data
        self.contracts[data.vt_symbol] = data

    def get_tick(self, vt_symbol):
        """Return the latest tick for vt_symbol, or None if unseen."""
        return self.ticks.get(vt_symbol)

    def get_order(self, vt_orderid):
        """Return the latest order for vt_orderid, or None if unseen."""
        return self.orders.get(vt_orderid)

    def get_trade(self, vt_tradeid):
        """Return the trade for vt_tradeid, or None if unseen."""
        return self.trades.get(vt_tradeid)

    def get_position(self, vt_positionid):
        """Return the latest position for vt_positionid, or None if unseen."""
        return self.positions.get(vt_positionid)

    def get_account(self, vt_accountid):
        """Return the latest account for vt_accountid, or None if unseen."""
        return self.accounts.get(vt_accountid)

    def get_contract(self, vt_symbol):
        """Return the contract for vt_symbol, or None if unseen."""
        return self.contracts.get(vt_symbol)

    def get_all_ticks(self):
        """Return a new list of every cached tick."""
        return list(self.ticks.values())

    def get_all_orders(self):
        """Return a new list of every cached order."""
        return list(self.orders.values())

    def get_all_trades(self):
        """Return a new list of every cached trade."""
        return list(self.trades.values())

    def get_all_positions(self):
        """Return a new list of every cached position."""
        return list(self.positions.values())

    def get_all_accounts(self):
        """Return a new list of every cached account."""
        return list(self.accounts.values())

    def get_all_contracts(self):
        """Return a new list of every cached contract."""
        return list(self.contracts.values())

    def get_all_active_orders(self, vt_symbol: str = ""):
        """
        Return active orders, optionally restricted to one vt_symbol.

        An empty vt_symbol (the default) returns every active order.
        """
        candidates = list(self.active_orders.values())
        if not vt_symbol:
            return candidates
        return [o for o in candidates if o.vt_symbol == vt_symbol]
|
|
|
|
|
|
class EmailEngine(BaseEngine):
    """
    Provides email sending function for VN Trader.

    send_email() builds a message and puts it on a queue. A worker thread,
    started lazily on the first send, delivers queued messages over
    SMTP_SSL using the "email.*" entries in SETTINGS.
    """

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create the (not yet started) worker thread and expose send_email."""
        super(EmailEngine, self).__init__(main_engine, event_engine, "email")

        self.thread = Thread(target=self.run)
        self.queue = Queue()
        self.active = False

        # Make email sending available directly on the main engine.
        self.main_engine.send_email = self.send_email

    def send_email(self, subject: str, content: str, receiver: str = ""):
        """
        Queue an email for delivery.

        Parameters:
            subject: message subject.
            content: plain-text message body.
            receiver: destination address; when empty, falls back to
                SETTINGS["email.receiver"].
        """
        # Start email engine when sending first email.
        if not self.active:
            self.start()

        # Use default receiver if not specified.
        if not receiver:
            receiver = SETTINGS["email.receiver"]

        msg = EmailMessage()
        msg["From"] = SETTINGS["email.sender"]
        # Honour the resolved receiver (previously the default receiver was
        # always used, silently ignoring the caller's argument).
        msg["To"] = receiver
        msg["Subject"] = subject
        msg.set_content(content)

        self.queue.put(msg)

    def run(self):
        """
        Worker loop: deliver queued messages until the engine is closed.

        Polls the queue with a 1-second timeout so that close() is noticed
        promptly even when no mail is pending.
        """
        while self.active:
            try:
                msg = self.queue.get(block=True, timeout=1)
            except Empty:
                continue

            try:
                with smtplib.SMTP_SSL(
                    SETTINGS["email.server"], SETTINGS["email.port"]
                ) as smtp:
                    smtp.login(
                        SETTINGS["email.username"], SETTINGS["email.password"]
                    )
                    smtp.send_message(msg)
            except Exception as exc:
                # Thread boundary: a failed delivery (network, SSL, auth,
                # bad settings) must not kill the worker, or every later
                # email would sit in the queue forever. Report and go on.
                self.main_engine.write_log(f"邮件发送失败:{exc!r}")

    def start(self):
        """Mark the engine active and start the worker thread."""
        self.active = True
        self.thread.start()

    def close(self):
        """Stop the worker thread (if it was started) and wait for it to exit."""
        if not self.active:
            return

        self.active = False
        self.thread.join()
|