diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index f05fff98..c5adb95b 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -5,6 +5,13 @@ General constant string used in VN Trader. from enum import Enum +class Color(Enum): + """ Kline color """ + RED = 'Red' + BLUE = 'Blue' + EQUAL = 'Equal' + + class Direction(Enum): """ Direction of order/trade/position. @@ -141,7 +148,9 @@ class Interval(Enum): """ Interval of bar data. """ + SECOND = "1s" MINUTE = "1m" HOUR = "1h" DAILY = "d" WEEKLY = "w" + RENKO = 'renko' diff --git a/vnpy/trader/database/database.py b/vnpy/trader/database/database.py index c6920158..fba26938 100644 --- a/vnpy/trader/database/database.py +++ b/vnpy/trader/database/database.py @@ -24,7 +24,8 @@ class BaseDatabaseManager(ABC): exchange: "Exchange", interval: "Interval", start: datetime, - end: datetime + end: datetime, + **kwargs ) -> Sequence["BarData"]: pass @@ -34,7 +35,8 @@ class BaseDatabaseManager(ABC): symbol: str, exchange: "Exchange", start: datetime, - end: datetime + end: datetime, + **kwargs ) -> Sequence["TickData"]: pass diff --git a/vnpy/trader/database/database_mongo.py b/vnpy/trader/database/database_mongo.py index e8509478..4d7da15a 100644 --- a/vnpy/trader/database/database_mongo.py +++ b/vnpy/trader/database/database_mongo.py @@ -112,6 +112,9 @@ class DbTickData(Document): symbol: str = StringField() exchange: str = StringField() datetime: datetime = DateTimeField() + date: str = StringField() + time: str = StringField() + trading_day: str = StringField() name: str = StringField() volume: float = FloatField() @@ -170,6 +173,9 @@ class DbTickData(Document): db_tick.symbol = tick.symbol db_tick.exchange = tick.exchange.value db_tick.datetime = tick.datetime + db_tick.date = tick.date + db_tick.time = tick.time + db_tick.trading_day = tick.trading_day db_tick.name = tick.name db_tick.volume = tick.volume db_tick.open_interest = tick.open_interest @@ -218,6 +224,9 @@ class DbTickData(Document): symbol=self.symbol, exchange=Exchange(self.exchange), datetime=self.datetime, + date=self.date, + time=self.time, + trading_day=self.trading_day, name=self.name, volume=self.volume, open_interest=self.open_interest, @@ -269,6 +278,7 @@ class MongoManager(BaseDatabaseManager): interval: Interval, start: datetime, end: datetime, + **kwargs ) -> Sequence[BarData]: s = DbBarData.objects( symbol=symbol, @@ -281,7 +291,7 @@ class MongoManager(BaseDatabaseManager): return data def load_tick_data( - self, symbol: str, exchange: Exchange, start: datetime, end: datetime + self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs ) -> Sequence[TickData]: s = DbTickData.objects( symbol=symbol, diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py index a627f3ce..4b076704 100644 --- a/vnpy/trader/database/database_sql.py +++ b/vnpy/trader/database/database_sql.py @@ -337,6 +337,7 @@ class SqlManager(BaseDatabaseManager): interval: Interval, start: datetime, end: datetime, + **kwargs ) -> Sequence[BarData]: s = ( self.class_bar.select() @@ -353,7 +354,7 @@ class SqlManager(BaseDatabaseManager): return data def load_tick_data( - self, symbol: str, exchange: Exchange, start: datetime, end: datetime + self, symbol: str, exchange: Exchange, start: datetime, end: datetime, **kwargs ) -> Sequence[TickData]: s = ( self.class_tick.select() diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index a9d2205b..db8df656 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -33,6 +33,9 @@ from .object import ( from .setting import SETTINGS from .utility import get_folder_path, TRADER_DIR +# 专有的logger文件 +from .util_logger import setup_logger + class MainEngine: """ @@ -239,6 +242,37 @@ class BaseEngine(ABC): self.event_engine = event_engine self.engine_name = engine_name + self.logger = None + + def create_logger(self, logger_name: str = 'base_engine'): + """ + 创建engine独有的日志 + :param logger_name: 日志名,缺省为engine的名称 + :return: + """ + log_path = get_folder_path("log") + log_filename = os.path.abspath(os.path.join(log_path, logger_name)) + print(u'create logger:{}'.format(log_filename)) + self.logger = setup_logger(file_name=log_filename, name=logger_name, + log_level=SETTINGS.get('log.level', logging.DEBUG)) + + def write_log(self, msg: str, source: str = "", level: int = logging.DEBUG): + """ + 写入日志 + :param msg: 日志内容 + :param source: 来源 + :param level: 日志级别 + :return: + """ + if self.logger: + if len(source) > 0: + msg = f'[{source}]{msg}' + self.logger.log(level, msg) + else: + log = LogData(msg=msg, level=level) + event = Event(EVENT_LOG, log) + self.event_engine.put(event) + def close(self): """""" pass diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 4ce95cff..e46adf1d 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from typing import Any, Sequence from copy import copy +from logging import INFO from vnpy.event import Event, EventEngine from .event import ( @@ -140,11 +141,11 @@ class BaseGateway(ABC): """ self.on_event(EVENT_CONTRACT, contract) - def write_log(self, msg: str): + def write_log(self, msg: str, level: int = INFO): """ Write a log event from gateway. """ - log = LogData(msg=msg, gateway_name=self.gateway_name) + log = LogData(msg=msg, level=level, gateway_name=self.gateway_name) self.on_log(log) @abstractmethod diff --git a/vnpy/trader/object.py b/vnpy/trader/object.py index 42d25387..01f5738b 100644 --- a/vnpy/trader/object.py +++ b/vnpy/trader/object.py @@ -103,6 +103,20 @@ class BarData(BaseData): self.vt_symbol = f"{self.symbol}.{self.exchange.value}" +@dataclass +class RenkoBarData(BarData): + """ + Renko bar data of a certain trading period. + """ + seconds: int = 0 # 当前Bar的秒数(针对RenkoBar) + high_seconds: int = -1 # 当前Bar的上限秒数 + low_seconds: int = -1 # 当前bar的下限秒数 + height: float = 3 # 当前Bar的高度限制(针对RenkoBar和RangeBar类) + up_band: float = 0 # 高位区域的基线 + down_band: float = 0 # 低位区域的基线 + low_time = None # 最后一次进入低位区域的时间 + high_time = None # 最后一次进入高位区域的时间 + @dataclass class OrderData(BaseData): """ diff --git a/vnpy/trader/util_logger.py b/vnpy/trader/util_logger.py index d87e1bdf..6c9d887f 100644 --- a/vnpy/trader/util_logger.py +++ b/vnpy/trader/util_logger.py @@ -159,11 +159,19 @@ class MultiprocessHandler(logging. FileHandler): self.handleError(record) -def setup_logger(filename, name=None, debug=False, force=False, backtesing=False): +def setup_logger(file_name: str, + name: str = None, + log_level: int = logging.DEBUG, + force: bool = False, + backtesing: bool = False): """ 设置日志文件,包括路径 自动在后面添加 "_日期.log" - :param logger_file_name: + :param file_name: 日志文件名 + :param name: logger 名 + :param log_level: 日志级别 + :param force: 是否强制更新日志名称 + :param backtesing: 是否为回测(回测输出的格式不同) :return: """ @@ -171,21 +179,21 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False global _fileHandler global _logger_filename - if _logger is not None and _logger_filename == filename and not force: + if _logger is not None and _logger_filename == file_name and not force: return _logger - if _logger_filename != filename or force: + if _logger_filename != file_name or force: if force: - _logger_filename = filename + _logger_filename = file_name # 定义日志输出格式 fmt = logging.Formatter(RECORD_FORMAT if not backtesing else BACKTEST_FORMAT) if name is None: - names = filename.replace('.log', '').split('/') + names = file_name.replace('.log', '').split('/') name = names[-1] logger = logging.getLogger(name) - if debug: + if log_level == logging.DEBUG: logger.setLevel(logging.DEBUG) stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setLevel(logging.DEBUG) @@ -193,26 +201,26 @@ def setup_logger(filename, name=None, debug=False, force=False, backtesing=False if not logger.hasHandlers(): logger.addHandler(stream_handler) - fileHandler = MultiprocessHandler(filename, encoding='utf8', interval='D') - if debug: - fileHandler.setLevel(logging.DEBUG) - else: - fileHandler.setLevel(logging.WARNING) - + # 创建文件日志 + fileHandler = MultiprocessHandler(file_name, encoding='utf8', interval='D') + fileHandler.setLevel(log_level) fileHandler.setFormatter(fmt) logger.addHandler(fileHandler) - if debug: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.WARNING) + # 设置logger的级别 + logger.setLevel(log_level) return logger return _logger -def get_logger(name=None): +def get_logger(name: str = None): + """ + 根据name获取logger + :param name: + :return: + """ global _logger if _logger is None: @@ -227,6 +235,7 @@ def get_logger(name=None): # -------------------测试代码------------ def single_func(para): + """ 测试单进程""" setup_logger('logs/MyLog{}'.format(para)) logger = get_logger() if para > 5: @@ -240,6 +249,7 @@ def single_func(para): def multi_func(): + """测试多进程""" # 启动多进程 pool = multiprocessing.Pool(multiprocessing.cpu_count()) setup_logger('logs/MyLog') diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index f805f0e4..94432105 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -19,7 +19,6 @@ import talib from .object import BarData, TickData from .constant import Exchange, Interval - log_formatter = logging.Formatter('[%(asctime)s] %(message)s') @@ -74,6 +73,30 @@ def get_underlying_symbol(symbol: str): return underlying_symbol.group(1) +@lru_cache() +def get_stock_exchange(code, vn=True): + """根据股票代码,获取交易所""" + # vn:取EXCHANGE_SSE 和 EXCHANGE_SZSE + code = str(code) + if len(code) < 6: + return '' + + market_id = 0 # 缺省深圳 + code = str(code) + if code[0] in ['5', '6', '9'] or code[:3] in ["009", "126", "110", "201", "202", "203", "204"]: + market_id = 1 # 上海 + try: + from vnpy.trader.constant import Exchange + if vn: + return Exchange.SSE.value if market_id == 1 else Exchange.SZSE.value + else: + return 'XSHG' if market_id == 1 else 'XSHE' + except Exception as ex: + print(u'加载数据异常:{}'.format(str(ex))) + + return '' + + @lru_cache() def get_full_symbol(symbol: str): """ @@ -140,6 +163,11 @@ def _get_trader_dir(temp_name: str): """ Get path where trader is running in. """ + # by incenselee + # 原方法,当前目录必须自建.vntrader子目录,否则在用户得目录下创建 + # 为兼容多账号管理,取消此方法。 + return Path.cwd(), Path.cwd() + cwd = Path.cwd() temp_path = cwd.joinpath(temp_name) @@ -161,6 +189,7 @@ def _get_trader_dir(temp_name: str): TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader") sys.path.append(str(TRADER_DIR)) +print(f'sys.path append: {str(TRADER_DIR)}') def get_file_path(filename: str):