diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index af4a49a8..2c384d02 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -23,9 +23,11 @@ Sample csv file: import csv from datetime import datetime +from peewee import chunked + from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData +from vnpy.trader.database import DbBarData, DB from vnpy.trader.engine import BaseEngine, MainEngine @@ -75,30 +77,37 @@ class CsvLoaderEngine(BaseEngine): with open(file_path, 'rt') as f: reader = csv.DictReader(f) + db_bars = [] + for item in reader: - db_bar = DbBarData() + dt = datetime.strptime(item[datetime_head], datetime_format) - db_bar.symbol = symbol - db_bar.exchange = exchange.value - db_bar.datetime = datetime.strptime( - item[datetime_head], datetime_format - ) - db_bar.interval = interval.value - db_bar.volume = item[volume_head] - db_bar.open_price = item[open_head] - db_bar.high_price = item[high_head] - db_bar.low_price = item[low_head] - db_bar.close_price = item[close_head] - db_bar.vt_symbol = vt_symbol - db_bar.gateway_name = "DB" + db_bar = { + "symbol": symbol, + "exchange": exchange.value, + "datetime": dt, + "interval": interval.value, + "volume": item[volume_head], + "open_price": item[open_head], + "high_price": item[high_head], + "low_price": item[low_head], + "close_price": item[close_head], + "vt_symbol": vt_symbol, + "gateway_name": "DB" + } - db_bar.replace() + db_bars.append(db_bar) # do some statistics count += 1 if not start: - start = db_bar.datetime + start = db_bar["datetime"] - end = db_bar.datetime + end = db_bar["datetime"] + + # Insert into DB + with DB.atomic(): + for batch in chunked(db_bars, 500): + DbBarData.insert_many(batch).on_conflict_replace().execute() return start, end, count diff --git a/vnpy/app/csv_loader/ui/widget.py b/vnpy/app/csv_loader/ui/widget.py index 6e89741c..30160d77 100644 --- a/vnpy/app/csv_loader/ui/widget.py +++ b/vnpy/app/csv_loader/ui/widget.py @@ -27,8 +27,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): self.setFixedWidth(300) self.setWindowFlags( - (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) - & ~QtCore.Qt.WindowMaximizeButtonHint) + (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) & + ~QtCore.Qt.WindowMaximizeButtonHint) file_button = QtWidgets.QPushButton("选择文件") file_button.clicked.connect(self.select_file) @@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): def select_file(self): """""" - result: str = QtWidgets.QFileDialog.getOpenFileName(self) + result: str = QtWidgets.QFileDialog.getOpenFileName( + self, filter="CSV (*.csv)") filename = result[0] if filename: self.file_edit.setText(filename)