From 2e8513fea6158bfb047fed6ccfa20e94a682c954 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Sun, 1 Dec 2019 00:28:17 +0800 Subject: [PATCH] =?UTF-8?q?[=E6=96=B0=E5=8A=9F=E8=83=BD]=20=E5=A4=9A?= =?UTF-8?q?=E8=BF=9B=E7=A8=8B=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/amqp/base.py | 14 +- vnpy/amqp/consumer.py | 29 ++-- vnpy/amqp/producer.py | 30 ++-- vnpy/amqp/test01_receiver.py | 3 +- vnpy/amqp/test02_task.py | 2 +- vnpy/amqp/test02_woker.py | 2 - vnpy/amqp/test03_subscriber.py | 3 +- vnpy/amqp/test06_rpc_client.py | 8 +- vnpy/amqp/test06_rpc_server.py | 20 +-- vnpy/amqp/test07_rpc_client.py | 20 +-- vnpy/app/cta_strategy/engine.py | 5 +- vnpy/data/tdx/tdx_common.py | 48 +++--- vnpy/data/tdx/tdx_future_data.py | 183 +++++++++++---------- vnpy/data/tdx/tdx_stock_data.py | 85 +++++----- vnpy/event/engine.py | 2 +- vnpy/trader/util_logger.py | 267 +++++++++++++++++++++++++++++++ vnpy/trader/utility.py | 2 + 17 files changed, 507 insertions(+), 216 deletions(-) create mode 100644 vnpy/trader/util_logger.py diff --git a/vnpy/amqp/base.py b/vnpy/amqp/base.py index 82d4ed4d..c3797ac5 100644 --- a/vnpy/amqp/base.py +++ b/vnpy/amqp/base.py @@ -1,13 +1,11 @@ # encoding: UTF-8 - import pika -class base_broker(): +class base_broker(): def __init__(self, host='localhost', port=5672, user='guest', password='guest', channel_number=1): """ - :param host: 连接rabbitmq的服务器地址(或者群集地址) :param port: 端口 :param user: 用户名 @@ -26,11 +24,9 @@ class base_broker(): # 创建连接 self.connection = pika.BlockingConnection( - pika.ConnectionParameters(host=self.host, port=self.port, - credentials=self.credentials, - heartbeat=0, socket_timeout=5, - ) - ) + pika.ConnectionParameters(host=self.host, port=self.port, + credentials=self.credentials, + heartbeat=0, socket_timeout=5)) # 创建一个频道,或者指定频段数字编号 self.channel = self.connection.channel( @@ -43,7 +39,7 @@ class base_broker(): """ try: self.connection.close() - except: + except Exception: pass self.connection = pika.BlockingConnection( diff --git a/vnpy/amqp/consumer.py b/vnpy/amqp/consumer.py index 95204de0..13a1ae90 100644 --- a/vnpy/amqp/consumer.py +++ b/vnpy/amqp/consumer.py @@ -1,14 +1,13 @@ # encoding: UTF-8 - # 消息消费者类(集合) +# 消息消费者类(集合) import json import pika import random import traceback from vnpy.amqp.base import base_broker -from threading import Thread -######### 模式1:接收者 ######### +# 模式1:接收者 class receiver(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', exchange='x', @@ -48,7 +47,8 @@ class receiver(base_broker): print(e) self.start() -######### 模式2:(执行)接收者######### + +# 模式2:(执行)接收者 class worker(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', exchange='x_work_queue', @@ -63,7 +63,7 @@ class worker(base_broker): exchange_type='direct', durable=True) - self.queue = self.channel.queue_declare(queue=queue,durable=True).method.queue + self.queue = self.channel.queue_declare(queue=queue, durable=True).method.queue print('worker use exchange:{},queue:{}'.format(exchange, self.queue)) self.channel.queue_bind(queue=self.queue, exchange=exchange, routing_key=self.routing_key) # 队列名采用服务端分配的临时队列 @@ -91,7 +91,8 @@ class worker(base_broker): print(str(e)) traceback.print_exc() -######### 模式3:发布 / 订阅(Pub/Sub)模式, 订阅者######### + +# 模式3:发布 / 订阅(Pub/Sub)模式, 订阅者 class subscriber(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', @@ -115,7 +116,7 @@ class subscriber(base_broker): # 缺省回调函数地址 self.cb_func = self.callback - def set_callback(self,cb_func): + def set_callback(self, cb_func): self.cb_func = cb_func def callback(self, chan, method_frame, _header_frame, body, userdata=None): @@ -134,9 +135,9 @@ class subscriber(base_broker): except Exception as ex: print('subscriber exception:{}'.format(str(ex))) traceback.print_exc() - #self.start() -######### 模式4: 路由模式 ######### + +# 模式4: 路由模式 class subscriber_routing(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', @@ -174,7 +175,7 @@ class subscriber_routing(base_broker): self.start() -######### 模式5:主题模式 ######### +# 模式5:主题模式 class subscriber_topic(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', @@ -211,7 +212,8 @@ class subscriber_topic(base_broker): print(e) self.start() -######### 模式6:RPC模式 (服务端) ######### + +# 模式6:RPC模式 (服务端) class rpc_server(base_broker): # 接收: # exchange: x_rpc @@ -323,10 +325,11 @@ class rpc_server(base_broker): print(e) self.start() + if __name__ == '__main__': import sys - if len(sys.argv) >=2: + if len(sys.argv) >= 2: print(sys.argv) from time import sleep c = subscriber(user='admin', password='admin') @@ -335,5 +338,3 @@ if __name__ == '__main__': while True: sleep(1) - - diff --git a/vnpy/amqp/producer.py b/vnpy/amqp/producer.py index 3726cb06..8891eded 100644 --- a/vnpy/amqp/producer.py +++ b/vnpy/amqp/producer.py @@ -1,15 +1,15 @@ # encoding: UTF-8 # 消息生产者类(集合) -import sys + import json import pika -import random import traceback from threading import Thread from uuid import uuid1 from vnpy.amqp.base import base_broker -######### 模式1:发送者 ######### + +# 模式1:发送者 class sender(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', exchange='x', queue_name='', routing_key='default'): @@ -61,7 +61,8 @@ class sender(base_broker): def exit(self): self.connection.close() -######### 模式2:工作队列,任务发布者 ######### + +# 模式2:工作队列,任务发布者 class task_creator(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', channel_number=1, queue_name='task_queue', routing_key='default', @@ -76,7 +77,7 @@ class task_creator(base_broker): # 通过channel,创建/使用一个queue。 queue = self.channel.queue_declare(self.queue_name, durable=True).method.queue - print(u'create/use queue:{}') + print(f'create/use queue:{queue}') # 通过channel,创建/使用一个网关 # exchange_type: direct # passive: 只是检查其是否存在 @@ -110,7 +111,8 @@ class task_creator(base_broker): def exit(self): self.connection.close() -######### 3、发布 / 订阅(Pub/Sub)模式,发布者 ######### + +# 3、发布 / 订阅(Pub/Sub)模式,发布者 class publisher(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', @@ -167,11 +169,12 @@ class publisher(base_broker): def exit(self): self.connection.close() -######### 4、路由模式:发布者 ######### + +# 4、路由模式:发布者 class publisher_routing(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', - channel_number=1, queue_name='', routing_key='default', exchange='x_direct'): + channel_number=1, queue_name='', routing_key='default', exchange='x_direct'): super().__init__(host, port, user, password, channel_number) self.queue_name = queue_name @@ -209,7 +212,8 @@ class publisher_routing(base_broker): def exit(self): self.connection.close() -######### 5、主题模式:发布者 ######### + +# 5、主题模式:发布者 class publisher_topic(base_broker): def __init__(self, host='localhost', port=5672, user='admin', password='admin', @@ -252,7 +256,7 @@ class publisher_topic(base_broker): self.connection.close() -######### 6、RPC模式(调用者) ######### +# 6、RPC模式(调用者) class rpc_client(base_broker): # 发送: # exchange: x_rpc @@ -269,7 +273,6 @@ class rpc_client(base_broker): super().__init__(host, port, user, password, channel_number=1) self.exchange = exchange - #self.queue_name = queue_name self.routing_key = routing_key # 通过channel,创建/使用一个网关 @@ -318,7 +321,7 @@ class rpc_client(base_broker): cb = self.cb_dict.pop(props.correlation_id, None) if cb: try: - cb(body) + cb(body) except Exception as ex: print('on_respone exception when call cb.{}'.format(str(ex))) traceback.print_exc() @@ -377,9 +380,10 @@ class rpc_client(base_broker): self.connection.close() if self.thread: self.thread.join() - except: + except Exception: pass + if __name__ == '__main__': import datetime import time diff --git a/vnpy/amqp/test01_receiver.py b/vnpy/amqp/test01_receiver.py index e6d3f421..e9410683 100644 --- a/vnpy/amqp/test01_receiver.py +++ b/vnpy/amqp/test01_receiver.py @@ -1,8 +1,7 @@ from vnpy.amqp.consumer import receiver -if __name__ == '__main__': - import sys +if __name__ == '__main__': from time import sleep c = receiver(user='admin', password='admin') diff --git a/vnpy/amqp/test02_task.py b/vnpy/amqp/test02_task.py index 37365973..75326da0 100644 --- a/vnpy/amqp/test02_task.py +++ b/vnpy/amqp/test02_task.py @@ -12,7 +12,7 @@ if __name__ == '__main__': while True: time.sleep(10) mission = {} - mission.update({'id':str(uuid1())}) + mission.update({'id': str(uuid1())}) mission.update({'templateName': u'TWAP 时间加权平均'}) mission.update({'direction': Direction.LONG}) mission.update({'vtSymbol': '518880'}) diff --git a/vnpy/amqp/test02_woker.py b/vnpy/amqp/test02_woker.py index b14c81a0..fff58122 100644 --- a/vnpy/amqp/test02_woker.py +++ b/vnpy/amqp/test02_woker.py @@ -4,8 +4,6 @@ from vnpy.amqp.consumer import worker if __name__ == '__main__': - import sys - from time import sleep c = worker(host='192.168.0.202', user='admin', password='admin') c.start() diff --git a/vnpy/amqp/test03_subscriber.py b/vnpy/amqp/test03_subscriber.py index b33951fe..3697e8a8 100644 --- a/vnpy/amqp/test03_subscriber.py +++ b/vnpy/amqp/test03_subscriber.py @@ -2,9 +2,8 @@ from vnpy.amqp.consumer import subscriber -if __name__ == '__main__': - import sys +if __name__ == '__main__': from time import sleep c = subscriber(user='admin', password='admin', exchange='x_fanout_md_tick') diff --git a/vnpy/amqp/test06_rpc_client.py b/vnpy/amqp/test06_rpc_client.py index 432fe744..ac090432 100644 --- a/vnpy/amqp/test06_rpc_client.py +++ b/vnpy/amqp/test06_rpc_client.py @@ -5,13 +5,15 @@ import json import random from vnpy.amqp.producer import rpc_client + def cb_function(*args): print('resp call back') for arg in args: print(u'{}'.format(arg)) + if __name__ == '__main__': - import datetime + import time c = rpc_client(host='localhost', user='admin', password='admin') @@ -25,9 +27,9 @@ if __name__ == '__main__': params.update({'p1': counter}) mission.update({'params': params}) msg = json.dumps(mission) - print(u'[x] rpc call :{}'.format(msg)) + print(f'[x] rpc call :{msg}') - c.call(msg,str(uuid1()), cb_function) + c.call(msg, str(uuid1()), cb_function) counter += 1 if counter > 100: diff --git a/vnpy/amqp/test06_rpc_server.py b/vnpy/amqp/test06_rpc_server.py index 79863ccb..12aa6611 100644 --- a/vnpy/amqp/test06_rpc_server.py +++ b/vnpy/amqp/test06_rpc_server.py @@ -1,22 +1,22 @@ -import os, sys, copy -# 将repostory的目录i,作为根目录,添加到系统环境中。 -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.append(ROOT_PATH) +# encoding: UTF-8 + +import copy +import argparse +from .consumer import rpc_server routing_key = 'default' -from vnpy.amqp.consumer import rpc_server -import argparse +def test_func01(p1, p2, p3): + print(f'test_func01:{p1} {p2} {p3}') + return p1 + p2 + p3 -def test_func01(p1,p2,p3): - print('test_func01:', p1, p2, p3) - return p1+p2+p3 def test_func02(p1, p2=0): - print('test_func02:', p1, p2) + print(f'test_func02:{p1} {p2}') return str(p1 + p2) + def get_strategy_names(): print(u'{}'.format(routing_key)) return ['stratege_name_01', 'strategy_name_02'] diff --git a/vnpy/amqp/test07_rpc_client.py b/vnpy/amqp/test07_rpc_client.py index e414c192..6ccfc695 100644 --- a/vnpy/amqp/test07_rpc_client.py +++ b/vnpy/amqp/test07_rpc_client.py @@ -1,23 +1,20 @@ -import os, sys -# 将repostory的目录i,作为根目录,添加到系统环境中。 -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) -sys.path.append(ROOT_PATH) +# encoding: UTF-8 -from vnpy.amqp.producer import rpc_client +from .producer import rpc_client from uuid import uuid1 import json -import random import argparse + def cb_function(*args): print('resp call back') for arg in args: - if isinstance(arg,bytes): + if isinstance(arg, bytes): print(u'{}'.format(arg.decode('utf-8'))) else: print(u'{}'.format(arg)) -from vnpy.trader.vtConstant import * + if __name__ == '__main__': # 参数分析 parser = argparse.ArgumentParser() @@ -32,16 +29,15 @@ if __name__ == '__main__': help='rabbit mq password') parser.add_argument('-x', '--exchange', type=str, default='exchange', help='rabbit mq exchange') - parser.add_argument('-q', '--queue', type=str, default='queue', - help='rabbit mq queue') parser.add_argument('-r', '--routing_key', type=str, default='default', help='rabbit mq routing_key') args = parser.parse_args() - import datetime import time - c = rpc_client(host=args.host, port=args.port, user=args.user, password=args.password, exchange=args.exchange, queue_name=args.queue, routing_key=args.routing_key) + c = rpc_client(host=args.host, port=args.port, user=args.user, + password=args.password, exchange=args.exchange, + routing_key=args.routing_key) counter = 0 while True: diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 01f16c05..68fd0457 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -9,6 +9,7 @@ from typing import Any, Callable from datetime import datetime, timedelta from concurrent.futures import ThreadPoolExecutor from copy import copy +from logging import INFO, ERROR, DEBUG from vnpy.event import Event, EventEngine from vnpy.trader.engine import BaseEngine, MainEngine @@ -893,14 +894,14 @@ class CtaEngine(BaseEngine): event = Event(EVENT_CTA_STRATEGY, data) self.event_engine.put(event) - def write_log(self, msg: str, strategy: CtaTemplate = None): + def write_log(self, msg: str, strategy: CtaTemplate = None, level: int = INFO): """ Create cta engine log event. """ if strategy: msg = f"{strategy.strategy_name}: {msg}" - log = LogData(msg=msg, gateway_name="CtaStrategy") + log = LogData(msg=msg, gateway_name="CtaStrategy", level=level) event = Event(type=EVENT_CTA_LOG, data=log) self.event_engine.put(event) diff --git a/vnpy/data/tdx/tdx_common.py b/vnpy/data/tdx/tdx_common.py index 16f67d62..c453a082 100644 --- a/vnpy/data/tdx/tdx_common.py +++ b/vnpy/data/tdx/tdx_common.py @@ -2,6 +2,7 @@ from functools import lru_cache + @lru_cache() def get_tdx_market_code(code): # 获取通达信股票的market code @@ -27,36 +28,33 @@ def get_tdx_market_code(code): # 10 - 季K 线 # 11 - 年K 线 PERIOD_MAPPING = {} -PERIOD_MAPPING['1min'] = 8 -PERIOD_MAPPING['5min'] = 0 -PERIOD_MAPPING['15min'] = 1 -PERIOD_MAPPING['30min'] = 2 -PERIOD_MAPPING['1hour'] = 3 -PERIOD_MAPPING['1day'] = 4 -PERIOD_MAPPING['1week'] = 5 +PERIOD_MAPPING['1min'] = 8 +PERIOD_MAPPING['5min'] = 0 +PERIOD_MAPPING['15min'] = 1 +PERIOD_MAPPING['30min'] = 2 +PERIOD_MAPPING['1hour'] = 3 +PERIOD_MAPPING['1day'] = 4 +PERIOD_MAPPING['1week'] = 5 PERIOD_MAPPING['1month'] = 6 - # 期货行情服务器清单 TDX_FUTURE_HOSTS = [ - {"ip": "112.74.214.43", "port": 7727, "name": "扩展市场深圳双线1"}, - {"ip": "120.24.0.77", "port": 7727, "name": "扩展市场深圳双线2"}, - {"ip": "47.107.75.159", "port": 7727, "name": "扩展市场深圳双线3"}, + {"ip": "112.74.214.43", "port": 7727, "name": "扩展市场深圳双线1"}, + {"ip": "120.24.0.77", "port": 7727, "name": "扩展市场深圳双线2"}, + {"ip": "47.107.75.159", "port": 7727, "name": "扩展市场深圳双线3"}, - {"ip": "113.105.142.136", "port": 443, "name": "扩展市场东莞主站"}, - {"ip": "113.105.142.133", "port": 443, "name": "港股期货东莞电信"}, + {"ip": "113.105.142.136", "port": 443, "name": "扩展市场东莞主站"}, + {"ip": "113.105.142.133", "port": 443, "name": "港股期货东莞电信"}, - {"ip": "119.97.185.5", "port": 7727, "name": "扩展市场武汉主站1"}, - {"ip": "119.97.185.7", "port": 7727, "name": "港股期货武汉主站1"}, - {"ip": "119.97.185.9", "port": 7727, "name": "港股期货武汉主站2"}, - {"ip": "59.175.238.38", "port": 7727, "name": "扩展市场武汉主站3"}, + {"ip": "119.97.185.5", "port": 7727, "name": "扩展市场武汉主站1"}, + {"ip": "119.97.185.7", "port": 7727, "name": "港股期货武汉主站1"}, + {"ip": "119.97.185.9", "port": 7727, "name": "港股期货武汉主站2"}, + {"ip": "59.175.238.38", "port": 7727, "name": "扩展市场武汉主站3"}, - {"ip": "202.103.36.71", "port": 443, "name": "扩展市场武汉主站2"}, - - {"ip": "47.92.127.181", "port": 7727, "name": "扩展市场北京主站"}, - {"ip": "106.14.95.149", "port": 7727, "name": "扩展市场上海双线"}, - {"ip": '218.80.248.229', 'port': 7721 ,"name":"备用服务器1"}, - {"ip": '124.74.236.94', 'port': 7721, "name": "备用服务器2"}, - {'ip': '58.246.109.27', 'port': 7721,"name": "备用服务器3"} - ] + {"ip": "202.103.36.71", "port": 443, "name": "扩展市场武汉主站2"}, + {"ip": "47.92.127.181", "port": 7727, "name": "扩展市场北京主站"}, + {"ip": "106.14.95.149", "port": 7727, "name": "扩展市场上海双线"}, + {"ip": '218.80.248.229', 'port': 7721, "name": "备用服务器1"}, + {"ip": '124.74.236.94', 'port': 7721, "name": "备用服务器2"}, + {'ip': '58.246.109.27', 'port': 7721, "name": "备用服务器3"}] diff --git a/vnpy/data/tdx/tdx_future_data.py b/vnpy/data/tdx/tdx_future_data.py index e0179bb9..2ba0b974 100644 --- a/vnpy/data/tdx/tdx_future_data.py +++ b/vnpy/data/tdx/tdx_future_data.py @@ -8,32 +8,37 @@ # - 1day在VNPY合成时不关心已经收到多少Bar, 所以影响也不大 # - 但其它分钟周期因为不好精确到每个品种, 修改后的freq可能有错 -from datetime import datetime, timezone, timedelta, time -import sys, os, csv, pickle, bz2, copy +import sys +import os +import pickle +import bz2 +import copy import json import traceback +from datetime import datetime, timedelta, time from logging import ERROR, INFO +from typing import Dict + from pandas import to_datetime from pytdx.exhq import TdxExHq_API from vnpy.trader.constant import Exchange from vnpy.trader.object import BarData -from vnpy.trader.utility import get_underlying_symbol, get_full_symbol, get_trading_date - -from vnpy.data.tdx.tdx_common import TDX_FUTURE_HOSTS, PERIOD_MAPPING +from vnpy.trader.utility import (get_underlying_symbol, get_full_symbol, get_trading_date) +from vnpy.data.tdx.tdx_common import (TDX_FUTURE_HOSTS, PERIOD_MAPPING) # 每个周期包含多少分钟 (估算值, 没考虑夜盘和10:15的影响) -NUM_MINUTE_MAPPING = {} -NUM_MINUTE_MAPPING['1min'] = 1 -NUM_MINUTE_MAPPING['5min'] = 5 -NUM_MINUTE_MAPPING['15min'] = 15 -NUM_MINUTE_MAPPING['30min'] = 30 -NUM_MINUTE_MAPPING['1hour'] = 60 -NUM_MINUTE_MAPPING['1day'] = 60*24 -NUM_MINUTE_MAPPING['1week'] = 60*24*7 -NUM_MINUTE_MAPPING['1month'] = 60*24*7*30 +NUM_MINUTE_MAPPING: Dict[str, int] = {} +NUM_MINUTE_MAPPING['1min'] = 1 +NUM_MINUTE_MAPPING['5min'] = 5 +NUM_MINUTE_MAPPING['15min'] = 15 +NUM_MINUTE_MAPPING['30min'] = 30 +NUM_MINUTE_MAPPING['1hour'] = 60 +NUM_MINUTE_MAPPING['1day'] = 60 * 24 +NUM_MINUTE_MAPPING['1week'] = 60 * 24 * 7 +NUM_MINUTE_MAPPING['1month'] = 60 * 24 * 7 * 30 Tdx_Vn_Exchange_Map = {} Tdx_Vn_Exchange_Map['47'] = Exchange.CFFEX @@ -41,19 +46,27 @@ Tdx_Vn_Exchange_Map['28'] = Exchange.CZCE Tdx_Vn_Exchange_Map['29'] = Exchange.DCE Tdx_Vn_Exchange_Map['30'] = Exchange.SHFE -Vn_Tdx_Exchange_Map = {v:k for k,v in Tdx_Vn_Exchange_Map.items()} +Vn_Tdx_Exchange_Map = {v: k for k, v in Tdx_Vn_Exchange_Map.items()} # 能源所与上期所,都归纳到 30 Vn_Tdx_Exchange_Map[Exchange.INE] = '30' -INIT_TDX_MARKET_MAP = {'URL9': 28,'WHL9':28,'ZCL9':28,'AL9':29,'BBL9':29,'BL9':29,'CL9':29,'CSL9':29,'EBL9':29,'EGL9':29,'FBL9':29,'IL9':29, -'JDL9':29,'JL9':29,'JML9':29,'LL9':29,'ML9':29,'PL9':29,'PPL9':29,'RRL9':29,'VL9':29,'YL9':29,'AGL9':30,'ALL9':30,'AUL9':30, -'BUL9':30,'CUL9':30,'FUL9':30,'HCL9':30,'NIL9':30,'NRL9':30,'PBL9':30,'RBL9':30,'RUL9':30,'SCL9':30,'SNL9':30,'SPL9':30,'SSL9':30,'WRL9':30, -'ZNL9':30,'APL9':28,'CFL9':28,'CJL9':28,'CYL9':28,'FGL9':28,'JRL9':28,'LRL9':28,'MAL9':28,'OIL9':28,'PML9':28,'RIL9':28,'RML9':28,'RSL9':28,'SFL9':28, -'SML9':28,'SRL9':28,'TAL9':28,'ICL9':47,'IFL9':47,'IHL9':47,'TFL9':47,'TL9':47,'TSL9':47} +INIT_TDX_MARKET_MAP = { + 'URL9': 28, 'WHL9': 28, 'ZCL9': 28, 'AL9': 29, 'BBL9': 29, 'BL9': 29, + 'CL9': 29, 'CSL9': 29, 'EBL9': 29, 'EGL9': 29, 'FBL9': 29, 'IL9': 29, + 'JDL9': 29, 'JL9': 29, 'JML9': 29, 'LL9': 29, 'ML9': 29, 'PL9': 29, + 'PPL9': 29, 'RRL9': 29, 'VL9': 29, 'YL9': 29, 'AGL9': 30, 'ALL9': 30, + 'AUL9': 30, 'BUL9': 30, 'CUL9': 30, 'FUL9': 30, 'HCL9': 30, 'NIL9': 30, + 'NRL9': 30, 'PBL9': 30, 'RBL9': 30, 'RUL9': 30, 'SCL9': 30, 'SNL9': 30, + 'SPL9': 30, 'SSL9': 30, 'WRL9': 30, 'ZNL9': 30, 'APL9': 28, 'CFL9': 28, + 'CJL9': 28, 'CYL9': 28, 'FGL9': 28, 'JRL9': 28, 'LRL9': 28, 'MAL9': 28, + 'OIL9': 28, 'PML9': 28, 'RIL9': 28, 'RML9': 28, 'RSL9': 28, 'SFL9': 28, + 'SML9': 28, 'SRL9': 28, 'TAL9': 28, 'ICL9': 47, 'IFL9': 47, 'IHL9': 47, + 'TFL9': 47, 'TL9': 47, 'TSL9': 47} # 常量 QSIZE = 500 ALL_MARKET_BEGIN_HOUR = 8 ALL_MARKET_END_HOUR = 16 + class TdxFutureData(object): # ---------------------------------------------------------------------- @@ -89,7 +102,7 @@ class TdxFutureData(object): # 创建api连接对象实例 try: - if self.api is None or self.connection_status == False: + if self.api is None or not self.connection_status: self.write_log(u'开始连接通达信行情服务器') self.api = TdxExHq_API(heartbeat=True, auto_retry=True, raise_exception=True) @@ -129,13 +142,13 @@ class TdxFutureData(object): with apix.connect(ip, port): if apix.get_instrument_count() > 10000: _timestamp = datetime.now() - __time1 - self.write_log('服务器{}:{},耗时:{}'.format(ip,port,_timestamp)) + self.write_log(f'服务器{ip}:{port},耗时:{_timestamp}') return _timestamp else: - self.write_log(u'该服务器IP {}无响应'.format(ip)) + self.write_log(f'该服务器IP {ip}无响应') return timedelta(9, 9, 0) - except: - self.write_error(u'tdx ping服务器,异常的响应{}'.format(ip)) + except Exception: + self.write_error(f'tdx ping服务器,异常的响应{ip}') return timedelta(9, 9, 0) # ---------------------------------------------------------------------- @@ -163,11 +176,11 @@ class TdxFutureData(object): # 取得所有的合约信息 num = self.api.get_instrument_count() - if not isinstance(num,int): + if not isinstance(num, int): return - all_contacts = sum([self.api.get_instrument_info((int(num / 500) - i) * 500, 500) for i in range(int(num / 500) + 1)],[]) - #[{"category":category,"market": int,"code":sting,"name":string,"desc":string},{}] + all_contacts = sum([self.api.get_instrument_info((int(num / 500) - i) * 500, 500) for i in range(int(num / 500) + 1)], []) + # [{"category":category,"market": int,"code":sting,"name":string,"desc":string},{}] # 对所有合约处理,更新字典 指数合约-tdx市场,指数合约-交易所 for tdx_contract in all_contacts: @@ -178,8 +191,6 @@ class TdxFutureData(object): if str(tdx_market_id) in Tdx_Vn_Exchange_Map: self.symbol_exchange_dict.update({tdx_symbol: Tdx_Vn_Exchange_Map.get(str(tdx_market_id))}) self.symbol_market_dict.update({tdx_symbol: tdx_market_id}) - #if 'L9' in tdx_symbol: - # print('\'{}\':{},'.format(tdx_symbol, Tdx_Vn_Exchange_Map.get(str(tdx_market_id)))) # ---------------------------------------------------------------------- def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None): @@ -190,30 +201,32 @@ class TdxFutureData(object): """ ret_bars = [] - tdx_symbol = symbol.upper().replace('_' , '') - tdx_symbol = tdx_symbol.replace('99' , 'L9') + tdx_symbol = symbol.upper().replace('_', '') + tdx_symbol = tdx_symbol.replace('99', 'L9') self.connect() if self.api is None: return False, ret_bars if tdx_symbol not in self.symbol_exchange_dict.keys(): - self.write_error(u'{} 合约{}/{}不在下载清单中: {}'.format(datetime.now(), symbol, tdx_symbol, self.symbol_exchange_dict.keys())) - return False,ret_bars + self.write_error(u'{} 合约{}/{}不在下载清单中: {}' + .format(datetime.now(), symbol, tdx_symbol, self.symbol_exchange_dict.keys())) + return False, ret_bars if period not in PERIOD_MAPPING.keys(): self.write_error(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) - return False,ret_bars + return False, ret_bars - tdx_period = PERIOD_MAPPING.get(period) + # tdx_period = PERIOD_MAPPING.get(period) if start_dt is None: self.write_log(u'没有设置开始时间,缺省为10天前') qry_start_date = datetime.now() - timedelta(days=10) else: qry_start_date = start_dt - end_date = datetime.combine(datetime.now() + timedelta(days=1),time(ALL_MARKET_END_HOUR, 0)) + end_date = datetime.combine(datetime.now() + timedelta(days=1), time(ALL_MARKET_END_HOUR, 0)) if qry_start_date > end_date: qry_start_date = end_date - self.write_log('{}开始下载tdx:{} {}数据, {} to {}.'.format(datetime.now(), tdx_symbol, period, qry_start_date, end_date)) + self.write_log('{}开始下载tdx:{} {}数据, {} to {}.' + .format(datetime.now(), tdx_symbol, period, qry_start_date, end_date)) # print('{}开始下载tdx:{} {}数据, {} to {}.'.format(datetime.now(), tdx_symbol, tdx_period, last_date, end_date)) try: @@ -223,7 +236,7 @@ class TdxFutureData(object): while _start_date > qry_start_date: _res = self.api.get_instrument_bars( PERIOD_MAPPING[period], - self.symbol_market_dict.get(tdx_symbol,0), + self.symbol_market_dict.get(tdx_symbol, 0), tdx_symbol, _pos, QSIZE) @@ -329,7 +342,7 @@ class TdxFutureData(object): return True, ret_bars except Exception as ex: - self.write_error('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc())) + self.write_error('exception in get:{},{},{}'.format(tdx_symbol, str(ex), traceback.format_exc())) # print('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc())) self.write_log(u'重置连接') self.api = None @@ -349,13 +362,13 @@ class TdxFutureData(object): if query_symbol != tdx_symbol: self.write_log('转换合约:{}=>{}'.format(tdx_symbol, query_symbol)) - index_symbol = short_symbol+'L9' + index_symbol = short_symbol + 'L9' self.connect() if self.api is None: return 0 - market_id = self.symbol_market_dict.get(index_symbol,0) + market_id = self.symbol_market_dict.get(index_symbol, 0) - _res = self.api.get_instrument_quote(market_id,query_symbol) + _res = self.api.get_instrument_quote(market_id, query_symbol) if not isinstance(_res, list): return 0 if len(_res) == 0: @@ -383,7 +396,7 @@ class TdxFutureData(object): def get_contracts(self, exchange): self.connect() - market_id = Vn_Tdx_Exchange_Map.get(exchange,None) + market_id = Vn_Tdx_Exchange_Map.get(exchange, None) if market_id is None: print(u'市场:{}配置不在Vn_Tdx_Exchange_Map:{}中,不能取市场下所有合约'.format(exchange, Vn_Tdx_Exchange_Map)) return [] @@ -392,7 +405,7 @@ class TdxFutureData(object): count = 100 results = [] while(True): - print(u'查询{}下:{}~{}个合约'.format(exchange, index, index+count)) + print(u'查询{}下:{}~{}个合约'.format(exchange, index, index + count)) result = self.api.get_instrument_quote_list(int(market_id), 3, index, count) results.extend(result) index += count @@ -423,12 +436,13 @@ class TdxFutureData(object): for contract in contracts: # 排除指数合约 code = contract.get('code') - if code[-2:] in ['L9','L8','L0','L1','L2','L3','50'] or (exchange == Exchange.CFFEX and code[-3:] in ['300', '500']): + if code[-2:] in ['L9', 'L8', 'L0', 'L1', 'L2', 'L3', '50'] or\ + (exchange == Exchange.CFFEX and code[-3:] in ['300', '500']): continue short_symbol = get_underlying_symbol(code).upper() - contract_list = short_contract_dict.get(short_symbol,[]) + contract_list = short_contract_dict.get(short_symbol, []) contract_list.append(contract) - short_contract_dict.update({short_symbol:contract_list}) + short_contract_dict.update({short_symbol: contract_list}) for k, v in short_contract_dict.items(): sorted_list = sorted(v, key=lambda c: c['ZongLiang']) @@ -468,7 +482,7 @@ class TdxFutureData(object): while(True): _res = self.api.get_transaction_data( - market=self.symbol_market_dict.get(symbol,0), + market=self.symbol_market_dict.get(symbol, 0), code=symbol, start=_pos, count=q_size) @@ -476,7 +490,7 @@ class TdxFutureData(object): for d in _res: dt = d.pop('date') # 星期1~星期6 - if dt.hour >= 20 and 1< dt.isoweekday()<=6: + if dt.hour >= 20 and 1 < dt.isoweekday() <= 6: dt = dt - timedelta(days=1) elif dt.hour >= 20 and dt.isoweekday() == 1: # 星期一取得20点后数据 @@ -495,12 +509,13 @@ class TdxFutureData(object): # 接口有bug,返回价格*1000,所以要除以1000 d.update({'price': d.get('price', 0) / 1000}) _datas = sorted(_res, key=lambda s: s['datetime']) + _datas - _pos += min(q_size,len(_res)) + _pos += min(q_size, len(_res)) if _res is not None and len(_res) > 0: - self.write_log(u'分段取分笔数据:{} ~{}, {}条,累计:{}条'.format( _res[0]['datetime'],_res[-1]['datetime'], len(_res),_pos)) + self.write_log(u'分段取分笔数据:{} ~{}, {}条,累计:{}条' + .format(_res[0]['datetime'], _res[-1]['datetime'], len(_res), _pos)) else: - break + break if len(_datas) >= max_data_size: break @@ -551,6 +566,7 @@ class TdxFutureData(object): if not os.path.isfile(cache_file): self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file)) return None + with bz2.BZ2File(cache_file, 'rb') as f: data = pickle.load(f) return data @@ -596,7 +612,7 @@ class TdxFutureData(object): while(True): _res = self.api.get_history_transaction_data( - market=self.symbol_market_dict.get(symbol,0), + market=self.symbol_market_dict.get(symbol, 0), date=date, code=symbol, start=_pos, @@ -605,9 +621,9 @@ class TdxFutureData(object): for d in _res: dt = d.pop('date') # 星期1~星期6 - if dt.hour >= 20 and 1< dt.isoweekday()<=6: + if dt.hour >= 20 and 1 < dt.isoweekday() <= 6: dt = dt - timedelta(days=1) - d.update({'datetime':dt}) + d.update({'datetime': dt}) elif dt.hour >= 20 and dt.isoweekday() == 1: # 星期一取得20点后数据 dt = dt - timedelta(days=3) @@ -627,21 +643,22 @@ class TdxFutureData(object): else: d.update({'datetime': dt}) # 接口有bug,返回价格*1000,所以要除以1000 - d.update({'price': d.get('price', 0)/1000}) + d.update({'price': d.get('price', 0) / 1000}) _datas = sorted(_res, key=lambda s: s['datetime']) + _datas _pos += min(q_size, len(_res)) if _res is not None and len(_res) > 0: - self.write_log(u'分段取分笔数据:{} ~{}, {}条,累计:{}条'.format( _res[0]['datetime'],_res[-1]['datetime'], len(_res),_pos)) + self.write_log(u'分段取分笔数据:{} ~{}, {}条,累计:{}条' + .format(_res[0]['datetime'], _res[-1]['datetime'], len(_res), _pos)) else: - break + break if len(_datas) >= max_data_size: break if len(_datas) == 0: self.write_error(u'{}分笔成交数据获取为空'.format(date)) - return False,_datas + return False, _datas # 缓存文件 if cache_folder: @@ -650,13 +667,15 @@ class TdxFutureData(object): return True, _datas except Exception as ex: - self.write_error('exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc())) + self.write_error('exception in get_transaction_data:{},{},{}' + .format(symbol, str(ex), traceback.format_exc())) self.write_error(u'当前异常服务器信息:{}'.format(self.best_ip)) self.write_log(u'重置连接') self.api = None self.connect(is_reconnect=True) return False, ret_datas + class FakeStrategy(object): def write_log(self, content, level=INFO): @@ -668,6 +687,7 @@ class FakeStrategy(object): def display_bar(self, bar, bar_is_completed=True, freq=1): print(u'{} {}'.format(bar.vtSymbol, bar.datetime)) + if __name__ == "__main__": t1 = FakeStrategy() @@ -687,24 +707,24 @@ if __name__ == "__main__": print('price={}'.format(price)) exit(0) # 获取主力合约 - #result = api_01.get_mi_contracts() - #str_result = json.dumps(result,indent=1, ensure_ascii=False) - #print(str_result) + # result = api_01.get_mi_contracts() + # str_result = json.dumps(result,indent=1, ensure_ascii=False) + # print(str_result) # 获取某个板块的合约 - #result = api_01.get_contracts(exchange=EXCHANGE_CZCE) + # result = api_01.get_contracts(exchange=EXCHANGE_CZCE) # 获取某个板块的主力合约 - #result = api_01.get_mi_contracts_from_exchange(exchange=EXCHANGE_CZCE) + # result = api_01.get_mi_contracts_from_exchange(exchange=EXCHANGE_CZCE) # 获取主力合约(从各个板块组合获取) - #result = api_01.get_mi_contracts2() - #print(u'一共{}个记录:{}'.format(len(result), [c.get('code') for c in result])) - #str_result = json.dumps(result,indent=1, ensure_ascii=False) - #print(str_result) + # result = api_01.get_mi_contracts2() + # print(u'一共{}个记录:{}'.format(len(result), [c.get('code') for c in result])) + # str_result = json.dumps(result,indent=1, ensure_ascii=False) + # print(str_result) - #all_99_ticks= api_01.get_99_contracts() - #str_99_ticks = json.dumps(all_99_ticks, indent=1, ensure_ascii=False) - #print(u'{}'.format(str_99_ticks)) + # all_99_ticks= api_01.get_99_contracts() + # str_99_ticks = json.dumps(all_99_ticks, indent=1, ensure_ascii=False) + # print(u'{}'.format(str_99_ticks)) # 获取历史分钟线 """ @@ -717,18 +737,17 @@ if __name__ == "__main__": corr_rate = round(abs(corr.iloc[0, 1]) * 100, 2) """ # api.get_bars(symbol, period='5min', callback=display_bar) - #api_01.get_bars('IF99', period='1day', callback=t1.display_bar) - #result,datas = api_01.get_transaction_data(symbol='ni1905') - #api_02 = TdxFutureData(t2) - #api_02.get_bars('IF99', period='1min', callback=t1.display_bar) + # api_01.get_bars('IF99', period='1day', callback=t1.display_bar) + # result,datas = api_01.get_transaction_data(symbol='ni1905') + # api_02 = TdxFutureData(t2) + # api_02.get_bars('IF99', period='1min', callback=t1.display_bar) # 获取当前交易日分时数据 - #ret,result = api_01.get_transaction_data('RB99') - #for r in result[0:10] + result[-10:]: - # print(r) + # ret,result = api_01.get_transaction_data('RB99') + # for r in result[0:10] + result[-10:]: + # print(r) # 获取历史分时数据 - ret,result = api_01.get_history_transaction_data('J99', '20191009') + ret, result = api_01.get_history_transaction_data('J99', '20191009') for r in result[0:10] + result[-10:]: print(r) - diff --git a/vnpy/data/tdx/tdx_stock_data.py b/vnpy/data/tdx/tdx_stock_data.py index fb2b4b98..42e56d0e 100644 --- a/vnpy/data/tdx/tdx_stock_data.py +++ b/vnpy/data/tdx/tdx_stock_data.py @@ -10,7 +10,11 @@ # https://rainx.gitbooks.io/pytdx/content/pytdx_hq.html # 华富资产 -import sys, os, pickle, bz2, traceback +import sys +import os +import pickle +import bz2 +import traceback from datetime import datetime, timedelta from logging import ERROR, INFO from pytdx.hq import TdxHq_API @@ -21,16 +25,17 @@ from vnpy.data.tdx.tdx_common import PERIOD_MAPPING, get_tdx_market_code # 每个周期包含多少分钟 NUM_MINUTE_MAPPING = {} -NUM_MINUTE_MAPPING['1min'] = 1 -NUM_MINUTE_MAPPING['5min'] = 5 -NUM_MINUTE_MAPPING['15min'] = 15 -NUM_MINUTE_MAPPING['30min'] = 30 -NUM_MINUTE_MAPPING['1hour'] = 60 -NUM_MINUTE_MAPPING['1day'] = 60*5.5 # 股票,收盘时间是15:00,开盘是9:30 +NUM_MINUTE_MAPPING['1min'] = 1 +NUM_MINUTE_MAPPING['5min'] = 5 +NUM_MINUTE_MAPPING['15min'] = 15 +NUM_MINUTE_MAPPING['30min'] = 30 +NUM_MINUTE_MAPPING['1hour'] = 60 +NUM_MINUTE_MAPPING['1day'] = 60 * 5.5 # 股票,收盘时间是15:00,开盘是9:30 # 常量 QSIZE = 800 + class TdxStockData(object): best_ip = None symbol_exchange_dict = {} # tdx合约与vn交易所的字典 @@ -69,7 +74,7 @@ class TdxStockData(object): """ # 创建api连接对象实例 try: - if self.api is None or self.connection_status == False: + if self.api is None or not self.connection_status: self.write_log(u'开始连接通达信股票行情服务器') self.api = TdxHq_API(heartbeat=True, auto_retry=True, raise_exception=True) @@ -88,7 +93,7 @@ class TdxStockData(object): def disconnect(self): if self.api is not None: - self.api= None + self.api = None # ---------------------------------------------------------------------- def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None): @@ -102,10 +107,10 @@ class TdxStockData(object): # 新版一劳永逸偷懒写法zzz if '.' in symbol: - tdx_code,market_str = symbol.split('.') - market_code = 1 if market_str.upper()== 'XSHG' else 0 - self.symbol_exchange_dict.update({tdx_code:symbol}) # tdx合约与vn交易所的字典 - self.symbol_market_dict.update({tdx_code:market_code}) # tdx合约与tdx市场的字典 + tdx_code, market_str = symbol.split('.') + market_code = 1 if market_str.upper() == 'XSHG' else 0 + self.symbol_exchange_dict.update({tdx_code: symbol}) # tdx合约与vn交易所的字典 + self.symbol_market_dict.update({tdx_code: market_code}) # tdx合约与tdx市场的字典 else: market_code = get_tdx_market_code(symbol) tdx_code = symbol @@ -118,12 +123,13 @@ class TdxStockData(object): ret_bars = [] if period not in PERIOD_MAPPING.keys(): - self.write_log(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys())), level=ERROR) + self.write_error(u'{} 周期{}不在下载清单中: {}' + .format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) # print(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) - return False,ret_bars + return False, ret_bars if self.api is None: - return False,ret_bars + return False, ret_bars tdx_period = PERIOD_MAPPING.get(period) @@ -137,14 +143,16 @@ class TdxStockData(object): if qry_start_date > end_date: qry_start_date = end_date - self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.'.format(datetime.now(), tdx_code, tdx_period, qry_start_date, end_date)) + self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.' + .format(datetime.now(), tdx_code, tdx_period, qry_start_date, end_date)) try: _start_date = end_date _bars = [] _pos = 0 while _start_date > qry_start_date: - _res = self.api.get_security_bars(category=PERIOD_MAPPING[period], + _res = self.api.get_security_bars( + category=PERIOD_MAPPING[period], market=market_code, code=tdx_code, start=_pos, @@ -169,20 +177,19 @@ class TdxStockData(object): data = data.assign(ticker=symbol) data['symbol'] = symbol data = data.drop( - ['year', 'month', 'day', 'hour', 'minute', 'price', 'ticker'], + ['year', 'month', 'day', 'hour', 'minute', 'price', 'ticker'], errors='ignore', axis=1) data = data.rename( index=str, - columns={'amount': 'volume', - }) + columns={'amount': 'volume'}) if len(data) == 0: print('{} Handling {}, len2={}..., continue'.format( str(datetime.now()), tdx_code, len(data))) return False, ret_bars # 通达信是以bar的结束时间标记的,vnpy是以bar开始时间标记的,所以要扣减bar本身的分钟数 - data['datetime'] = data['datetime'].apply(lambda x:x-timedelta(minutes=NUM_MINUTE_MAPPING.get(period,1))) + data['datetime'] = data['datetime'].apply(lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1))) data['trading_date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d'))) data['date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d'))) data['time'] = data['datetime'].apply(lambda x: (x.strftime('%H:%M:%S'))) @@ -202,7 +209,8 @@ class TdxStockData(object): add_bar.close = float(row['close']) add_bar.volume = float(row['volume']) except Exception as ex: - self.write_error('error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc())) + self.write_error('error when convert bar:{},ex:{},t:{}' + .format(row, str(ex), traceback.format_exc())) # print('error when convert bar:{},ex:{},t:{}'.format(row, str(ex), traceback.format_exc())) return False @@ -224,9 +232,9 @@ class TdxStockData(object): freq = NUM_MINUTE_MAPPING[period] - int((index - current_datetime).total_seconds() / 60) callback(add_bar, bar_is_completed, freq) - return True,ret_bars + return True, ret_bars except Exception as ex: - self.write_error('exception in get:{},{},{}'.format(tdx_code,str(ex), traceback.format_exc())) + self.write_error('exception in get:{},{},{}'.format(tdx_code, str(ex), traceback.format_exc())) # print('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc())) self.write_log(u'重置连接') TdxStockData.api = None @@ -236,7 +244,7 @@ class TdxStockData(object): def save_cache(self, cache_folder, cache_symbol, cache_date, data_list): """保存文件到缓存""" - os.makedirs(cache_folder,exist_ok=True) + os.makedirs(cache_folder, exist_ok=True) if not os.path.exists(cache_folder): self.write_error('缓存目录不存在:{},不能保存'.format(cache_folder)) @@ -332,16 +340,17 @@ class TdxStockData(object): _pos += min(q_size, len(_res)) if _res is not None and len(_res) > 0: - self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}条'.format(date, _res[0]['time'],_res[-1]['time'], len(_res),_pos)) + self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}条' + .format(date, _res[0]['time'], _res[-1]['time'], len(_res), _pos)) else: - break + break if len(_datas) >= max_data_size: break if len(_datas) == 0: self.write_error(u'{}分笔成交数据获取为空'.format(date)) - return False,_datas + return False, _datas for d in _datas: dt = datetime.strptime(str(date) + ' ' + d.get('time'), '%Y%m%d %H:%M') @@ -351,7 +360,7 @@ class TdxStockData(object): if last_dt < dt + timedelta(seconds=59): last_dt = last_dt + timedelta(seconds=1) d.update({'datetime': last_dt}) - d.update({'volume': d.pop('vol',0)}) + d.update({'volume': d.pop('vol', 0)}) d.update({'trading_date': last_dt.strftime('%Y-%m-%d')}) _datas = sorted(_datas, key=lambda s: s['datetime']) @@ -366,16 +375,17 @@ class TdxStockData(object): self.write_error('exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc())) return False, ret_datas + if __name__ == "__main__": class T(object): - def write_log(self,content, level=INFO): + def write_log(self, content, level=INFO): if level == INFO: print(content) else: - print(content,file=sys.stderr) + print(content, file=sys.stderr) - def display_bar(self,bar, bar_is_completed=True, freq=1): - print(u'{} {}'.format(bar.vtSymbol,bar.datetime)) + def display_bar(self, bar, bar_is_completed=True, freq=1): + print(u'{} {}'.format(bar.vtSymbol, bar.datetime)) t1 = T() t2 = T() @@ -383,12 +393,12 @@ if __name__ == "__main__": api_01 = TdxStockData(t1) # 获取历史分钟线 - #api_01.get_bars('002024', period='1hour', callback=t1.display_bar) + # api_01.get_bars('002024', period='1hour', callback=t1.display_bar) # api.get_bars(symbol, period='5min', callback=display_bar) # api.get_bars(symbol, period='1day', callback=display_bar) - #api_02 = TdxData(t2) - #api_02.get_bars('601390', period='1day', callback=t1.display_bar) + # api_02 = TdxData(t2) + # api_02.get_bars('601390', period='1day', callback=t1.display_bar) # 获取历史分时数据 # ret,result = api_01.get_history_transaction_data('RB99', '20190909') @@ -399,4 +409,3 @@ if __name__ == "__main__": ret, result = api_01.get_history_transaction_data('600410', '20190925') for r in result[0:10] + result[-10:]: print(r) - diff --git a/vnpy/event/engine.py b/vnpy/event/engine.py index 6a7854c4..6ad2fa8d 100644 --- a/vnpy/event/engine.py +++ b/vnpy/event/engine.py @@ -98,7 +98,7 @@ class EventEngine: execute_ms = (int(round(t2 * 1000))) - (int(round(t1 * 1000))) if execute_ms > self._over_ms: print(f'运行 general {event.type} {handler_name} 耗时:{execute_ms}ms > {self._over_ms}ms', - file=sys.stderr) + file=sys.stderr) def _process(self, event: Event): """ diff --git a/vnpy/trader/util_logger.py b/vnpy/trader/util_logger.py new file mode 100644 index 00000000..d87e1bdf --- /dev/null +++ b/vnpy/trader/util_logger.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python +# coding=utf8 + +import os +import sys +import re +import logging +import threading +import multiprocessing +from datetime import datetime + + +RECORD_FORMAT = "%(levelname)s [%(asctime)s][%(filename)s:%(lineno)d] %(message)s" +BACKTEST_FORMAT = "%(levelname)s %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +_logger = None +_fileHandler = None +_logger_filename = None + +thread_data = threading.local() + + +class MultiprocessHandler(logging. FileHandler): + """支持多进程的TimedRotatingFileHandler""" + def __init__(self, filename: str, + interval: str = 'D', + backup_count: int = 0, + encoding: str = None, + delay: bool = False): + """filename 日志文件名, + interval 时间间隔的单位, + backup_count 保留文件个数,0表示不删除 + delay 是否开启 OutSteam缓存 + True 表示开启缓存,OutStream输出到缓存,待缓存区满后,刷新缓存区,并输出缓存数据到文件。 + False表示不缓存,OutStrea直接输出到文件""" + self.prefix = filename + self.backup_count = backup_count + self.interval = interval.upper() + # 正则匹配 年-月-日 + self.re_match = r"^\d{4}-\d{2}-\d{2}" + + # S 每秒建立一个新文件 + # M 每分钟建立一个新文件 + # H 每天建立一个新文件 + # D 每天建立一个新文件 + self.interval_formater_dict = { + 'S': "%Y-%m-%d-%H-%M-%S", + 'M': "%Y-%m-%d-%H-%M", + 'H': "%Y-%m-%d-%H", + 'D': "%Y-%m-%d" + } + + # 日志文件日期后缀 + self.formater = self.interval_formater_dict.get(interval) + if not self.formater: + raise ValueError(u"指定的日期间隔单位无效: %s" % self.interval) + + # 使用当前时间,格式化文件格式化字符串 + self.file_path = u'{}_{}.log'.format(self.prefix, datetime.now().strftime(self.formater)) + # 获得文件夹路径 + _dir = os.path.dirname(self.file_path) + try: + # 如果日志文件夹不存在,则创建文件夹 + if not os.path.exists(_dir): + os.makedirs(_dir) + except Exception as ex: + print(f'创建log文件夹{self.file_path}失败:{str(ex)}', file=sys.stderr) + pass + + print(u'MultiprocessHandler create logger:{}'.format(self.file_path)) + + logging.FileHandler.__init__(self, self.file_path, 'a+', encoding, delay) + + def should_change_file(self): + """更改日志写入目的写入文件 + :return True 表示已更改,False 表示未更改""" + # 以当前时间获得新日志文件路径 + _filePath = u'{}_{}.log'.format(self.prefix, datetime.now().strftime(self.formater)) + # 新日志文件日期 不等于 旧日志文件日期,则表示 已经到了日志切分的时候 + # 更换日志写入目的为新日志文件。 + # 例如 按 天 (D)来切分日志 + # 当前新日志日期等于旧日志日期,则表示在同一天内,还不到日志切分的时候 + # 当前新日志日期不等于旧日志日期,则表示不在 + # 同一天内,进行日志切分,将日志内容写入新日志内。 + if _filePath != self.file_path: + self.file_path = _filePath + return True + return False + + def do_change_file(self): + """输出信息到日志文件,并删除多于保留个数的所有日志文件""" + # 日志文件的绝对路径 + self.baseFilename = os.path.abspath(self.file_path) + # stream == OutStream + # stream is not None 表示 OutStream中还有未输出完的缓存数据 + if self.stream: + # flush close 都会刷新缓冲区,flush不会关闭stream,close则关闭stream + # self.stream.flush() + self.stream.close() + # 关闭stream后必须重新设置stream为None,否则会造成对已关闭文件进行IO操作。 + self.stream = None + # delay 为False 表示 不OutStream不缓存数据 直接输出 + # 所有,只需要关闭OutStream即可 + if not self.delay: + # 这个地方如果关闭colse那么就会造成进程往已关闭的文件中写数据,从而造成IO错误 + # delay == False 表示的就是 不缓存直接写入磁盘 + # 我们需要重新在打开一次stream + # self.stream.close() + self.stream = self._open() + # 删除多于保留个数的所有日志文件 + if self.backup_count > 0: + print('删除日志') + for s in self.get_expired_files(): + print(s) + os.remove(s) + + def get_expired_files(self): + """获得过期需要删除的日志文件""" + # 分离出日志文件夹绝对路径 + # split返回一个元组(absFilePath,fileName) + # 例如:split('I:\ScripPython\char4\mybook\util\logs\mylog.2017-03-19) + # 返回(I:\ScripPython\char4\mybook\util\logs, mylog.2017-03-19) + # _ 表示占位符,没什么实际意义, + dir_name, _ = os.path.split(self.baseFilename) + file_names = os.listdir(dir_name) + result = [] + # self.prefix 为日志文件名 列如:mylog.2017-03-19 中的 mylog + # 加上 点号 . 方便获取点号后面的日期 + prefix = self.prefix + '.' + plen = len(prefix) + for file_name in file_names: + if file_name[:plen] == prefix: + # 日期后缀 mylog.2017-03-19 中的 2017-03-19 + suffix = file_name[plen:] + # 匹配符合规则的日志文件,添加到result列表中 + if re.compile(self.re_match).match(suffix): + result.append(os.path.join(dir_name, file_name)) + result.sort() + + # 返回 待删除的日志文件 + # 多于 保留文件个数 backupCount的所有前面的日志文件。 + if len(result) < self.backup_count: + result = [] + else: + result = result[:len(result) - self.backup_count] + return result + + def emit(self, record): + """发送一个日志记录 + 覆盖FileHandler中的emit方法,logging会自动调用此方法""" + try: + if self.should_change_file(): + self.do_change_file() + logging.FileHandler.emit(self, record) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.handleError(record) + + +def setup_logger(filename, name=None, debug=False, force=False, backtesing=False): + """ + 设置日志文件,包括路径 + 自动在后面添加 "_日期.log" + :param logger_file_name: + :return: + """ + + global _logger + global _fileHandler + global _logger_filename + + if _logger is not None and _logger_filename == filename and not force: + return _logger + + if _logger_filename != filename or force: + if force: + _logger_filename = filename + + # 定义日志输出格式 + fmt = logging.Formatter(RECORD_FORMAT if not backtesing else BACKTEST_FORMAT) + if name is None: + names = filename.replace('.log', '').split('/') + name = names[-1] + + logger = logging.getLogger(name) + if debug: + logger.setLevel(logging.DEBUG) + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(fmt) + if not logger.hasHandlers(): + logger.addHandler(stream_handler) + + fileHandler = MultiprocessHandler(filename, encoding='utf8', interval='D') + if debug: + fileHandler.setLevel(logging.DEBUG) + else: + fileHandler.setLevel(logging.WARNING) + + fileHandler.setFormatter(fmt) + logger.addHandler(fileHandler) + + if debug: + logger.setLevel(logging.DEBUG) + else: + logger.setLevel(logging.WARNING) + + return logger + + return _logger + + +def get_logger(name=None): + global _logger + + if _logger is None: + _logger = logging.getLogger(name) + return _logger + + if _logger.name != name: + return logging.getLogger(name) + + return _logger + + +# -------------------测试代码------------ +def single_func(para): + setup_logger('logs/MyLog{}'.format(para)) + logger = get_logger() + if para > 5: + # print u'more than 5' + logger.info(u'{}大于 More than 5'.format(para)) + return True + else: + # print 'less' + logger.info('{}Less than 5'.format(para)) + return False + + +def multi_func(): + # 启动多进程 + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + setup_logger('logs/MyLog') + logger = get_logger() + logger.info('main process') + task_list = [] + + for i in range(0, 10): + task_list.append(pool.apply_async(single_func, (i,))) + + [res.get() for res in task_list] + + pool.close() + pool.join() + + +if __name__ == '__main__': + + # 创建启动主进程日志 + setup_logger('logs/MyLog') + logger = get_logger() + logger.info("info into multiprocessing") + + # 测试所有多进程日志 + multi_func() diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index c54564dc..f805f0e4 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -42,6 +42,7 @@ def func_time(over_ms: int = 0): return wrapper return run + @lru_cache() def get_underlying_symbol(symbol: str): """ @@ -72,6 +73,7 @@ def get_underlying_symbol(symbol: str): return underlying_symbol.group(1) + @lru_cache() def get_full_symbol(symbol: str): """