from dataclasses import asdict
from decimal import ROUND_HALF_UP, Decimal
from pyqqq.brokerage.ebest.simple import EBestSimpleDomesticStock
from pyqqq.brokerage.kis.simple import KISSimpleDomesticStock
from pyqqq.data.domestic import get_tickers
from pyqqq.datatypes import *
from pyqqq.utils.api_client import raise_for_status, send_request
from pyqqq.utils.array import find
from pyqqq.utils.logger import get_logger
from pyqqq.utils.market_schedule import get_market_schedule, get_last_trading_day
from pyqqq.utils.singleton import singleton
from typing import Dict, Optional, List
import asyncio
import os
import pyqqq.config as c
import datetime as dtm
@singleton
class TrackerSocket:
    """
    Order-event socket shared by trade trackers.

    Runs ``simple_api.listen_order_event`` in a background asyncio task and relays every
    received ``OrderEvent`` to each registered tracker callback. The class is decorated
    with ``@singleton`` (presumably so every ``TradingTracker`` shares one listener —
    confirm against ``pyqqq.utils.singleton``).

    Args:
        simple_api (EBestSimpleDomesticStock | KISSimpleDomesticStock): simple brokerage API object
    """

    def __init__(self, simple_api: EBestSimpleDomesticStock | KISSimpleDomesticStock):
        self.simple_api = simple_api
        self.task: Optional[asyncio.Task] = None  # background listener task; None until start()
        self.stop_event = asyncio.Event()  # passed to listen_order_event; set by stop()
        self.logger = get_logger(__name__ + ".TrackerSocket")
        self.trading_tracker_counter = 1  # next tracker number returned by add_tracker()
        self.event_callbacks: Dict[int, callable] = {}  # tracker number -> callback

    async def start(self):
        """
        Start listening for order events in a background task.

        Does nothing (besides logging) if the listener task is already running.
        """
        if self.task is not None and not self.task.done():
            self.logger.info("TrackerSocket already started!")
            return
        # A previous stop() leaves the event set; clear it so a restarted listener is not
        # handed an already-signalled stop event.
        self.stop_event.clear()
        self.task = asyncio.create_task(self._listen_order_event())
        self.logger.info("TrackerSocket started!")

    async def _listen_order_event(self):
        """
        Consume order events from the brokerage API and relay each one to the callbacks.

        Cancellation and any other exception end the listener; both are logged, not raised.
        """
        self.logger.info("Listening for order events...")
        try:
            async for event in self.simple_api.listen_order_event(self.stop_event):
                self._relay_order_event(event)
        except asyncio.CancelledError:
            self.logger.info("Order event listening cancelled.")
        except Exception as e:
            self.logger.exception(f"Error while listening to order events: {e}")

    def _relay_order_event(self, event: OrderEvent):
        """
        Pass an order event to every registered callback.

        A failing callback is logged and does not prevent the others from running.

        Args:
            event (OrderEvent): order event object
        """
        self.logger.debug(f"Relay order event: {event}")
        # Iterate over a snapshot: a callback may add/remove trackers synchronously, which
        # would otherwise raise "dictionary changed size during iteration" and end the listener.
        for callback in list(self.event_callbacks.values()):
            try:
                callback(event)
            except Exception as e:
                self.logger.exception(f"Error in callback {callback}: {e}")

    def add_tracker(self, callback: callable):
        """
        Register a TradingTracker callback.

        Args:
            callback (callable): function that receives each ``OrderEvent``

        Returns:
            int: tracker number to pass to ``remove_tracker``
        """
        tracker_number = self.trading_tracker_counter
        self.event_callbacks[tracker_number] = callback
        self.trading_tracker_counter += 1
        self.logger.info(f"Tracker added: {tracker_number}")
        return tracker_number

    async def remove_tracker(self, tracker_number: int):
        """
        Unregister a TradingTracker callback; stops the socket once none remain.

        Args:
            tracker_number (int): number returned by ``add_tracker``
        """
        if tracker_number in self.event_callbacks:
            del self.event_callbacks[tracker_number]
            self.logger.info(f"Tracker removed: {tracker_number}")
        else:
            self.logger.warning(f"Tracker {tracker_number} not found!")

        if not self.event_callbacks:
            self.logger.info("No more trackers, stopping TrackerSocket.")
            await self.stop()

    async def stop(self):
        """
        Stop listening for order events.

        Sets ``stop_event`` (which ``TradingTracker`` schedule monitors also watch),
        cancels the listener task and waits for it to finish.
        """
        if self.task is None or self.task.done():
            self.logger.info("TrackerSocket already stopped!")
            return
        self.stop_event.set()
        self.task.cancel()
        try:
            await self.task
        except asyncio.CancelledError:
            pass
        self.logger.info("TrackerSocket stopped!")
class TradingTracker:
    """
    Tracks trading activity for one brokerage account.

    Receives order events from the shared ``TrackerSocket``, keeps the held positions and
    pending (unfilled) orders up to date, and records trading history.

    Args:
        simple_api (EBestSimpleDomesticStock | KISSimpleDomesticStock): simple brokerage API object
        fee_rate (Decimal): brokerage commission rate (default: 0.015%)
    """

    logger = get_logger(__name__ + ".TradingTracker")

    def __init__(
        self,
        simple_api: EBestSimpleDomesticStock | KISSimpleDomesticStock,
        fee_rate: Decimal = Decimal("0.00015"),  # KIS (BanKIS) / LS Securities commission rate 0.015%
    ):
        self.positions: List[StockPosition] = []
        """ Held positions """
        self.pending_orders: List[StockOrder] = []
        """ Pending (unfilled) orders """
        self.on_pending_order_update: Optional[callable] = None
        """ (deprecated) pending-order update callback """
        self.on_position_update: Optional[callable] = None
        """ (deprecated) position update callback """
        self.on_pending_order_update_callback_dict: Dict[object, callable] = {}
        """ Pending-order update callbacks
        key: callback_id
        value: callback[callable]
        Args:
            status (str): event status. 'accepted', 'cancelled', 'completed', 'partial'
            order (StockOrder): order information
        """
        self.on_position_update_callback_dict: Dict[object, callable] = {}
        """ Position update callbacks
        key: callback_id
        value: callback[callable]
        Args:
            type (str): event type. 'added', 'modified', 'removed'
            position (StockPosition): position information
        """
        self.task: asyncio.Task = None
        """ Kept for backward compatibility; start() stores its background tasks in ``tasks`` """
        self.tasks: List[asyncio.Task] = []
        """ Background tasks created by start() (socket start + schedule monitor) """
        self.simple_api = simple_api
        self.account_no = None
        self.fee_rate = fee_rate  # brokerage commission rate
        self.tax_rate = Decimal("0.0018")  # KOSPI/KOSDAQ transaction tax on sells: 0.18%
        self.tickers: Dict[str, Dict] = {}  # ticker info keyed by asset code
        self.ticker_date: dtm.datetime = None  # when ticker info was last refreshed
        self.save_trading_history = False  # whether to POST trading history to the API
        self.last_t_day_tickers: Dict[str, Dict] = {}  # previous trading day's ticker info (fallback)
        self.tracker_socket = TrackerSocket(simple_api)
        self.tracker_number = None
        self.started = False
        self.callback_id = 0

    async def start(self):
        """
        Start tracking trades.

        Resolves the account number, loads ticker info, syncs positions and pending orders,
        registers with the shared TrackerSocket, and launches the schedule monitor.
        """
        if self.started:
            self.logger.info("Trading tracker already started!")
            return

        if isinstance(self.simple_api, EBestSimpleDomesticStock):
            account_info = self.simple_api.get_account()
            self.account_no = account_info["account_no"]
        elif isinstance(self.simple_api, KISSimpleDomesticStock):
            self.account_no = self.simple_api.account_no + self.simple_api.account_product_code

        self.logger.info(f"Trading tracker started! Account No: {self.account_no} / save history: {self.save_trading_history}")

        self._fetch_tickers()
        self._sync_positions_and_pending_orders()

        if len(self.pending_orders) > 0:
            self.logger.info("Initial pending orders:")
            for o in self.pending_orders:
                self.logger.info(f"- {o.order_no}({o.org_order_no})\t{o.side}\t{o.asset_code}\t{o.filled_quantity}/{o.quantity}\t{o.is_pending}")

        self.tracker_number = self.tracker_socket.add_tracker(self._handle_order_event)
        self.tasks = [
            asyncio.create_task(self.tracker_socket.start()),
            asyncio.create_task(self._monitor_schedule()),
        ]
        self.started = True

    def enable_save_trading_history(self):
        """Enable POSTing trading history and daily positions to the PyQQQ API."""
        self.save_trading_history = True

    def _fetch_tickers(self):
        """Load ticker info for today and for the previous trading day (used as a fallback)."""
        today = dtm.date.today()

        df = get_tickers(today)
        df.reset_index(inplace=True)
        self.tickers.update({d["code"]: d for d in df.to_dict(orient="records")})
        self.ticker_date = dtm.datetime.now()

        last_t_day_df = get_tickers(get_last_trading_day(today))
        last_t_day_df.reset_index(inplace=True)
        self.last_t_day_tickers.update({d["code"]: d for d in last_t_day_df.to_dict(orient="records")})

    def _sync_positions_and_pending_orders(self):
        """Replace local positions and pending orders with the brokerage's current state."""
        self.positions = self.simple_api.get_positions()
        # Clear live-valuation fields; nothing in this class keeps them up to date afterwards.
        for p in self.positions:
            p.current_price = None
            p.current_value = None
            p.current_pnl = None
            p.current_pnl_value = None
        self.pending_orders = self.simple_api.get_pending_orders(exchanges=list(OrderExchange))

    async def stop(self):
        """
        Stop tracking trades.

        Cancels the background tasks, waits for them, and unregisters from the shared
        TrackerSocket (which stops itself when no trackers remain). Safe to call when the
        tracker was never started or is already stopped.
        """
        if not self.started:
            self.logger.info("Trading tracker already stopped!")
            return
        for t in self.tasks:
            t.cancel()
        # return_exceptions=True: a task cancelled before it ever ran finishes as cancelled,
        # which would otherwise re-raise CancelledError out of stop().
        await asyncio.gather(*self.tasks, return_exceptions=True)
        self.tasks = []
        await self.tracker_socket.remove_tracker(self.tracker_number)
        self.tracker_number = None
        self.started = False

    async def _monitor_schedule(self):
        """
        Run time-of-day maintenance jobs until the shared TrackerSocket is stopped.

        Checked about once a minute. On days the market is not fully closed:

        - 30 minutes before the regular-session open: refresh ticker info and resync
          positions/pending orders.
        - At the regular-session close: resync positions/pending orders and save positions.
        - At 18:00 (end of after-hours trading): resync positions/pending orders.

        A job runs when its minute falls inside ``(previous check, current check]`` instead
        of requiring an exact minute match, so drift from ``sleep(60)`` plus API latency can
        no longer skip a job.
        """
        # Minute-of-day of the previous check. Starting one below "now" lets a job that is
        # scheduled for the current minute run on the first pass.
        last_seq_no = self._calc_clock_seq() - 1
        while not self.tracker_socket.stop_event.is_set():
            try:
                seq_no = self._calc_clock_seq()
                # The sequence restarts at 0 after midnight; widen the window back to midnight.
                prev_seq_no = last_seq_no if last_seq_no <= seq_no else -1
                last_seq_no = seq_no

                market_schedule = get_market_schedule(dtm.date.today())
                if not market_schedule.full_day_closed:
                    self._run_scheduled_jobs(market_schedule, prev_seq_no, seq_no)
            except asyncio.CancelledError:
                return
            except Exception as e:
                self.logger.exception(f"Error on monitoring schedule: {e}")

            # Sleep outside the job's try block so a failing job cannot become a busy loop.
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                return

    def _run_scheduled_jobs(self, market_schedule, prev_seq_no: int, seq_no: int):
        """Run every scheduled job whose minute lies in ``(prev_seq_no, seq_no]``."""

        def is_due(t: dtm.time) -> bool:
            return prev_seq_no < self._calc_clock_seq(t=t) <= seq_no

        # 30 minutes before the regular session opens: refresh ticker info
        ticker_refresh_time = (dtm.datetime.combine(dtm.date.today(), market_schedule.open_time) - dtm.timedelta(minutes=30)).time()
        if is_due(ticker_refresh_time):
            self._fetch_tickers()
            self._sync_positions_and_pending_orders()
        # Regular session close: resync orders/positions and save positions
        if is_due(market_schedule.close_time):
            self._sync_positions_and_pending_orders()
            self.save_positions()
        # End of after-hours trading: resync orders/positions
        if is_due(dtm.time(18, 0, 0)):
            self._sync_positions_and_pending_orders()

    def _calc_clock_seq(
        self,
        interval=dtm.timedelta(seconds=60),
        t: dtm.time = None,
    ):
        """
        Return the index of the ``interval``-sized slot since today's midnight.

        Uses the current time when ``t`` is None, otherwise ``t`` on today's date.
        With the default 60-second interval this is the minute of the day.
        """
        midnight = dtm.datetime.combine(dtm.date.today(), dtm.time.min)
        clock_time = dtm.datetime.now() if t is None else dtm.datetime.combine(dtm.date.today(), t)
        elapsed = clock_time - midnight
        return int(elapsed.total_seconds() / interval.total_seconds())

    def _find_pending_order(self, order_no) -> StockOrder:
        """Return the tracked pending order with ``order_no`` (or whatever ``find`` returns when absent)."""
        return find(lambda x: x.order_no == order_no, self.pending_orders)

    def _find_order_in_today_orders(self, order_no) -> StockOrder:
        """Look up ``order_no`` in today's order history from the brokerage (KRX, NXT, SOR)."""
        today_orders = self.simple_api.get_today_order_history(order_no=order_no, exchanges=[OrderExchange.KRX, OrderExchange.NXT, OrderExchange.SOR])
        return find(lambda x: x.order_no == order_no, today_orders)

    def _find_position(self, asset_code) -> StockPosition:
        """Return the tracked position for ``asset_code`` (or whatever ``find`` returns when absent)."""
        return find(lambda x: x.asset_code == asset_code, self.positions)

    def _recalc_average_purchase_price(self, asset_code, quantity, price):
        """
        Return the average purchase price after buying ``quantity`` at ``price``.

        Blends with the existing position (if any) and rounds half-up to 4 decimal places.
        """
        p = self._find_position(asset_code)
        total_value = price * quantity
        total_quantity = quantity
        if p is not None:
            prev_value = Decimal(p.average_purchase_price * p.quantity)
            total_value += prev_value
            total_quantity += p.quantity
        return Decimal(total_value / total_quantity).quantize(Decimal("0.0001"), rounding=ROUND_HALF_UP)

    def _handle_order_event(
        self,
        event: OrderEvent,
    ):
        """Dispatch an order event for this tracker's account to the matching handler."""
        self.logger.debug(
            f"handle_order_event: accno={event.account_no} order_no={event.order_no} (org={event.org_order_no})\tside={event.side}\tcode={event.asset_code}\tfilled={event.filled_quantity}\torder_qty={event.quantity}\tevent_type={event.event_type}"
        )
        if event.account_no != self.account_no:
            return

        if event.event_type == "accepted":
            self._handle_accept_order_event(event)
        elif event.event_type == "cancelled":
            self._handle_cancel_order_event(event)
        elif event.event_type == "executed":
            self._handle_execution_order_event(event)

    def _handle_accept_order_event(self, event: OrderEvent):
        """
        Track a newly accepted order.

        When the event references an original order (``org_order_no``), that order's
        pending quantity is moved to the new order. Only brand-new sell orders reserve
        additional sellable quantity on the position.
        """
        self.logger.debug(f"accept event: order_no={event.order_no} (org={event.org_order_no}) qty={event.quantity} side={event.side}")
        # Look up the original order before appending the new one so the lookup can never
        # match the order being added.
        org_order = self._find_pending_order(event.org_order_no)

        order = StockOrder(
            order_no=event.order_no,
            asset_code=event.asset_code,
            side=event.side,
            order_type=event.order_type,
            quantity=event.quantity,
            price=event.price,
            filled_quantity=0,
            pending_quantity=event.quantity,
            order_time=event.filled_time,
            org_order_no=event.org_order_no,
            exchange=event.exchange,
        )
        self.pending_orders.append(order)

        if org_order is not None:
            if org_order.pending_quantity == event.quantity:
                self.pending_orders.remove(org_order)
            else:
                org_order.pending_quantity -= event.quantity

        self._notify_pending_order_update("accepted", order)

        # A replacement order takes over quantity the original order already reserved, so
        # reducing the sellable quantity again would double-count it.
        if event.side == OrderSide.SELL and org_order is None:
            position = self._find_position(event.asset_code)
            if position is not None:
                position.sell_possible_quantity -= event.quantity

    def _handle_cancel_order_event(self, event: OrderEvent):
        """Drop a cancelled order and release its unfilled sell quantity back to the position."""
        self.logger.debug(f"cancel event: order_no={event.order_no}")
        order = self._find_pending_order(event.order_no)
        if order is None:
            return

        # Capture the unfilled quantity BEFORE zeroing it; it is what becomes sellable again.
        released_quantity = order.pending_quantity
        self.pending_orders.remove(order)
        order.pending_quantity = 0
        order.is_pending = False

        if event.side == OrderSide.SELL:
            position = self._find_position(event.asset_code)
            if position is not None:
                position.sell_possible_quantity += released_quantity

        self._notify_pending_order_update("cancelled", order)

    def _handle_execution_order_event(self, event: OrderEvent):
        """
        Apply a (partial) fill to the pending order and the position, record history, notify.

        ``event.filled_quantity`` is treated as the quantity filled by this execution.

        Raises:
            Exception: on a sell fill for an asset with no tracked position.
        """
        order_no = event.order_no
        side = event.side
        asset_code = event.asset_code
        filled_quantity = event.filled_quantity
        total_filled_quantity = filled_quantity

        order = self._find_pending_order(order_no)
        position = self._find_position(asset_code)
        position_event_type = "modified"

        if order is not None:
            order.filled_quantity += filled_quantity
            order.pending_quantity -= filled_quantity
            total_filled_quantity = order.filled_quantity

        self.logger.debug(f"execution event: order_no={event.order_no} filled={total_filled_quantity} total={event.quantity} filled_price={event.filled_price}")

        if side == OrderSide.SELL:
            if position is None:
                raise Exception("Position not found")
            position.quantity -= filled_quantity
            if position.quantity == 0:
                self.positions.remove(position)
                position_event_type = "removed"
        elif side == OrderSide.BUY:
            if position is not None:
                position.average_purchase_price = self._recalc_average_purchase_price(asset_code, filled_quantity, event.filled_price)
                position.quantity += filled_quantity
                position.sell_possible_quantity += filled_quantity
            else:
                asset_info = self.tickers.get(asset_code)
                if asset_info is None:
                    self.logger.warning(f"Ticker info not found for asset code: {asset_code}")
                    asset_info = self.last_t_day_tickers.get(asset_code, {})
                position = StockPosition(
                    asset_code=asset_code,
                    asset_name=asset_info.get("name", ""),
                    quantity=filled_quantity,
                    sell_possible_quantity=filled_quantity,
                    average_purchase_price=Decimal(event.filled_price),
                )
                self.positions.append(position)
                position_event_type = "added"

        partial = event.quantity != total_filled_quantity
        if not partial and order is not None:
            self.pending_orders.remove(order)
            order.is_pending = False

        # History upload is best-effort: a failed POST must not suppress the notifications.
        try:
            self._save_trading_history(
                asset_code,
                side,
                order_no,
                event.filled_price,
                total_filled_quantity,
                position.average_purchase_price,
                event.filled_time,
                partial,
            )
        except Exception as e:
            self.logger.exception(f"Failed to save trading history for order {order_no}: {e}")

        if order is None:
            self.logger.warning(f"Order not found for order_no: {order_no}")
            try:
                order = self._find_order_in_today_orders(order_no)
            except Exception as e:
                self.logger.exception(f"Failed to look up order {order_no} in today's orders: {e}")
            self.logger.warning(f"Order found in today's orders: {order}")

        self._notify_pending_order_update("partial" if partial else "completed", order)
        self._notify_position_update(position_event_type, position)

    def _refresh_positions(self):
        """Replace local positions with the brokerage's current positions."""
        self.positions = self.simple_api.get_positions()

    def _refresh_pending_orders(self):
        """Reload pending orders, carrying over ``average_purchase_price`` from the previous copies."""
        old_pending_orders = {o.order_no: o for o in self.pending_orders}
        self.pending_orders = self.simple_api.get_pending_orders(exchanges=list(OrderExchange))
        for o in self.pending_orders:
            old_pending_order = old_pending_orders.get(o.order_no)
            if old_pending_order is not None:
                # NOTE(review): assumes StockOrder instances carry average_purchase_price — confirm.
                o.average_purchase_price = old_pending_order.average_purchase_price

    def _notify_pending_order_update(self, status: str, order: StockOrder):
        """Call the deprecated single callback, then every registered pending-order callback."""
        callbacks = list(self.on_pending_order_update_callback_dict.values())
        if self.on_pending_order_update is not None:  # deprecated
            callbacks.insert(0, self.on_pending_order_update)
        for callback in callbacks:
            # Isolate subscribers so one failure does not skip the rest.
            try:
                callback(status, order)
            except Exception as e:
                self.logger.exception(f"Error in pending order update callback {callback}: {e}")

    def _notify_position_update(self, type: str, position: StockPosition = None):
        """Call the deprecated single callback, then every registered position callback."""
        callbacks = list(self.on_position_update_callback_dict.values())
        if self.on_position_update is not None:  # deprecated
            callbacks.insert(0, self.on_position_update)
        for callback in callbacks:
            # Isolate subscribers so one failure does not skip the rest.
            try:
                callback(type, position)
            except Exception as e:
                self.logger.exception(f"Error in position update callback {callback}: {e}")

    def add_pending_order_update_callback(self, callback, callback_id=None):
        """Register ``callback(status, order)``; returns its id (auto-assigned when None)."""
        if callback_id is None:
            callback_id = self.callback_id
            self.callback_id += 1
        self.on_pending_order_update_callback_dict[callback_id] = callback
        return callback_id

    def remove_pending_order_update_callback(self, callback_id):
        """Unregister a pending-order callback; unknown or None ids are ignored."""
        if callback_id is not None and callback_id in self.on_pending_order_update_callback_dict:
            del self.on_pending_order_update_callback_dict[callback_id]

    def add_position_update_callback(self, callback, callback_id=None):
        """Register ``callback(type, position)``; returns its id (auto-assigned when None)."""
        if callback_id is None:
            callback_id = self.callback_id
            self.callback_id += 1
        self.on_position_update_callback_dict[callback_id] = callback
        return callback_id

    def remove_position_update_callback(self, callback_id):
        """Unregister a position callback; unknown or None ids are ignored."""
        if callback_id is not None and callback_id in self.on_position_update_callback_dict:
            del self.on_position_update_callback_dict[callback_id]

    def _save_trading_history(
        self,
        asset_code: str,
        side: OrderSide,
        order_no: str,
        filled_price: int,
        filled_quantity: int,
        average_purchase_price: Decimal = None,
        executed_time: dtm.datetime = None,
        partial: bool = False,
    ):
        """
        Build a ``TradingHistory`` record (fee, tax and, for sells, PnL) and send it.

        Sells of ``EQUITY``-type tickers (the default when the type is unknown) are charged
        ``tax_rate``; both legs are charged ``fee_rate``.
        """
        asset_info = self.tickers.get(asset_code)
        if asset_info is None:
            self.logger.warning(f"[_save_trading_history] Ticker info not found for asset code: {asset_code}")
            asset_info = self.last_t_day_tickers.get(asset_code, {})

        is_equity = asset_info.get("type", "EQUITY") == "EQUITY"
        fee = filled_price * filled_quantity * self.fee_rate
        tax = 0
        pnl = None
        pnl_rate = None

        if side == OrderSide.SELL:
            sell_value = filled_price * filled_quantity
            if is_equity:
                tax = sell_value * self.tax_rate
            buy_value = average_purchase_price * filled_quantity
            buy_fee = buy_value * self.fee_rate
            pnl = sell_value - buy_value - fee - tax - buy_fee
            pnl_rate = pnl / buy_value * 100 if buy_value != 0 else 0

        data = TradingHistory(
            date=dtm.date.today().strftime("%Y%m%d"),
            order_no=order_no,
            side="buy" if side == OrderSide.BUY else "sell",
            asset_code=asset_code,
            quantity=filled_quantity,
            filled_price=filled_price,
            average_purchase_price=float(average_purchase_price) if average_purchase_price is not None else None,
            tax=float(tax) if tax is not None else None,
            fee=float(fee) if fee is not None else None,
            pnl=float(pnl) if pnl is not None else None,
            pnl_rate=float(pnl_rate) if pnl_rate is not None else None,
            executed_time=(int(executed_time.timestamp() * 1000) if executed_time else None),
            partial=partial,
        )
        self._send_trading_history(data)

    def _send_trading_history(self, history: TradingHistory):
        """POST a trading-history record to the PyQQQ API when history saving is enabled."""
        if not self.save_trading_history:
            return

        url = f"{c.PYQQQ_API_URL}/analytics/trades/{history.order_no}"
        data = asdict(history)
        data["brokerage"] = "ebest" if isinstance(self.simple_api, EBestSimpleDomesticStock) else "kis"
        data["account_no"] = self.account_no
        strategy_name = os.getenv("STRATEGY_NAME")
        if strategy_name is not None:
            data["strategy_name"] = strategy_name

        r = send_request("POST", url, json=data)
        raise_for_status(r)
        self.logger.info(f"save trading history: {data}")

    def save_positions(self):
        """
        POST today's positions and account summary to the PyQQQ API.

        Skipped unless history saving is enabled and ``STRATEGY_NAME`` is set.
        """
        if not self.save_trading_history:
            return

        strategy_name = os.getenv("STRATEGY_NAME")
        if strategy_name is None:
            return

        url = f"{c.PYQQQ_API_URL}/analytics/positions"
        positions = []
        for p in self.simple_api.get_positions():
            d = asdict(p)
            d["average_purchase_price"] = float(d["average_purchase_price"])
            # The key may exist with a None value; float(None) would raise.
            d["current_pnl"] = float(d["current_pnl"]) if d.get("current_pnl") is not None else None
            positions.append(d)

        account = self.simple_api.get_account()
        account["pnl_rate"] = float(account["pnl_rate"]) if account.get("pnl_rate") is not None else None

        req_body = {
            "date": dtm.date.today().strftime("%Y%m%d"),
            "brokerage": ("ebest" if isinstance(self.simple_api, EBestSimpleDomesticStock) else "kis"),
            "account_no": self.account_no,
            "positions": positions,
            "account": account,
            "strategy_name": strategy_name,
        }
        self.logger.debug(f"save positions request: {req_body}")
        r = send_request("POST", url, json=req_body)
        raise_for_status(r)
        self.logger.info(f"save positions: {positions}")