longport/quote/
core.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use comfy_table::Table;
7use itertools::Itertools;
8use longport_candlesticks::{TradeSessionType, UpdateAction};
9use longport_httpcli::HttpClient;
10use longport_proto::quote::{
11    self, AdjustType, MarketTradeDayRequest, MarketTradeDayResponse, MultiSecurityRequest, Period,
12    PushQuoteTag, SecurityCandlestickRequest, SecurityCandlestickResponse,
13    SecurityStaticInfoResponse, SubscribeRequest, UnsubscribeRequest,
14};
15use longport_wscli::{
16    CodecType, Platform, ProtocolVersion, RateLimit, WsClient, WsClientError, WsEvent, WsSession,
17};
18use time::{Date, OffsetDateTime};
19use time_tz::OffsetDateTimeExt;
20use tokio::{
21    sync::{mpsc, oneshot},
22    time::{Duration, Instant},
23};
24
25use crate::{
26    Config, Error, Market, Result,
27    config::PushCandlestickMode,
28    quote::{
29        Candlestick, PushCandlestick, PushEvent, PushEventDetail, PushQuote, PushTrades,
30        RealtimeQuote, SecurityBoard, SecurityBrokers, SecurityDepth, Subscription, Trade,
31        TradeSession, TradeSessions, cmd_code,
32        store::{Candlesticks, Store, TailCandlestick, get_market},
33        sub_flags::SubFlags,
34        types::QuotePackageDetail,
35        utils::{convert_trade_session, format_date, parse_date},
36    },
37};
38
/// Pause inserted before every attempt to reconnect to the quote server.
const RECONNECT_DELAY: Duration = Duration::from_millis(2_000);
40
41pub(crate) enum Command {
42    Request {
43        command_code: u8,
44        body: Vec<u8>,
45        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
46    },
47    Subscribe {
48        symbols: Vec<String>,
49        sub_types: SubFlags,
50        is_first_push: bool,
51        reply_tx: oneshot::Sender<Result<()>>,
52    },
53    Unsubscribe {
54        symbols: Vec<String>,
55        sub_types: SubFlags,
56        reply_tx: oneshot::Sender<Result<()>>,
57    },
58    SubscribeCandlesticks {
59        symbol: String,
60        period: Period,
61        trade_sessions: TradeSessions,
62        reply_tx: oneshot::Sender<Result<Vec<Candlestick>>>,
63    },
64    UnsubscribeCandlesticks {
65        symbol: String,
66        period: Period,
67        reply_tx: oneshot::Sender<Result<()>>,
68    },
69    Subscriptions {
70        reply_tx: oneshot::Sender<Vec<Subscription>>,
71    },
72    GetRealtimeQuote {
73        symbols: Vec<String>,
74        reply_tx: oneshot::Sender<Vec<RealtimeQuote>>,
75    },
76    GetRealtimeDepth {
77        symbol: String,
78        reply_tx: oneshot::Sender<SecurityDepth>,
79    },
80    GetRealtimeTrade {
81        symbol: String,
82        count: usize,
83        reply_tx: oneshot::Sender<Vec<Trade>>,
84    },
85    GetRealtimeBrokers {
86        symbol: String,
87        reply_tx: oneshot::Sender<SecurityBrokers>,
88    },
89    GetRealtimeCandlesticks {
90        symbol: String,
91        period: Period,
92        count: usize,
93        reply_tx: oneshot::Sender<Vec<Candlestick>>,
94    },
95}
96
97#[derive(Debug, Default)]
98struct TradingDays {
99    normal_days: HashMap<Market, HashSet<Date>>,
100    half_days: HashMap<Market, HashSet<Date>>,
101}
102
103impl TradingDays {
104    #[inline]
105    fn half_days(&self, market: Market) -> Days {
106        Days(self.half_days.get(&market))
107    }
108}
109
110#[derive(Debug, Copy, Clone)]
111struct Days<'a>(Option<&'a HashSet<Date>>);
112
113impl longport_candlesticks::Days for Days<'_> {
114    #[inline]
115    fn contains(&self, date: Date) -> bool {
116        match self.0 {
117            Some(days) => days.contains(&date),
118            None => false,
119        }
120    }
121}
122
123#[derive(Debug)]
124pub(crate) struct MarketPackageDetail {
125    pub(crate) market: String,
126    pub(crate) packages: Vec<QuotePackageDetail>,
127    pub(crate) warning: String,
128}
129
130pub(crate) struct Core {
131    config: Arc<Config>,
132    rate_limit: Vec<(u8, RateLimit)>,
133    command_rx: mpsc::UnboundedReceiver<Command>,
134    push_tx: mpsc::UnboundedSender<PushEvent>,
135    event_tx: mpsc::UnboundedSender<WsEvent>,
136    event_rx: mpsc::UnboundedReceiver<WsEvent>,
137    http_cli: HttpClient,
138    ws_cli: WsClient,
139    session: Option<WsSession>,
140    close: bool,
141    subscriptions: HashMap<String, SubFlags>,
142    trading_days: TradingDays,
143    store: Store,
144    member_id: i64,
145    quote_level: String,
146    quote_package_details: Vec<QuotePackageDetail>,
147    push_candlestick_mode: PushCandlestickMode,
148}
149
150impl Core {
151    pub(crate) async fn try_new(
152        config: Arc<Config>,
153        command_rx: mpsc::UnboundedReceiver<Command>,
154        push_tx: mpsc::UnboundedSender<PushEvent>,
155    ) -> Result<Self> {
156        let http_cli = config.create_http_client();
157        let otp = http_cli.get_otp_v2().await?;
158
159        let (event_tx, event_rx) = mpsc::unbounded_channel();
160
161        tracing::info!("connecting to quote server");
162        let (url, res) = config.create_quote_ws_request().await;
163        let request = res.map_err(WsClientError::from)?;
164
165        let mut ws_cli = WsClient::open(
166            request,
167            ProtocolVersion::Version1,
168            CodecType::Protobuf,
169            Platform::OpenAPI,
170            event_tx.clone(),
171            vec![],
172        )
173        .await?;
174
175        tracing::info!(url = url, "quote server connected");
176
177        let session = ws_cli.request_auth(otp, config.create_metadata()).await?;
178
179        // fetch user profile
180        let resp = ws_cli
181            .request::<_, quote::UserQuoteProfileResponse>(
182                cmd_code::QUERY_USER_QUOTE_PROFILE,
183                None,
184                quote::UserQuoteProfileRequest {
185                    language: config.language.to_string(),
186                },
187            )
188            .await?;
189        let member_id = resp.member_id;
190        let quote_level = resp.quote_level;
191        let (quote_package_details, quote_package_details_by_market) = resp
192            .quote_level_detail
193            .map(|details| {
194                Ok::<_, Error>((
195                    details
196                        .by_package_key
197                        .into_values()
198                        .map(TryInto::try_into)
199                        .collect::<Result<Vec<_>>>()?,
200                    details
201                        .by_market_code
202                        .into_iter()
203                        .map(|(market, market_packages)| {
204                            Ok(MarketPackageDetail {
205                                market,
206                                packages: market_packages
207                                    .packages
208                                    .into_iter()
209                                    .map(TryInto::try_into)
210                                    .collect::<Result<Vec<_>>>()?,
211                                warning: market_packages.warning_msg,
212                            })
213                        })
214                        .collect::<Result<Vec<_>>>()?,
215                ))
216            })
217            .transpose()?
218            .unwrap_or_default();
219        let rate_limit: Vec<(u8, RateLimit)> = resp
220            .rate_limit
221            .iter()
222            .map(|config| {
223                (
224                    config.command as u8,
225                    RateLimit {
226                        interval: Duration::from_secs(1),
227                        initial: config.burst as usize,
228                        max: config.burst as usize,
229                        refill: config.limit as usize,
230                    },
231                )
232            })
233            .collect();
234        ws_cli.set_rate_limit(rate_limit.clone());
235
236        let current_trade_days = fetch_trading_days(&ws_cli).await?;
237        let push_candlestick_mode = config.push_candlestick_mode.unwrap_or_default();
238
239        let mut table = Table::new();
240        for market_packages in quote_package_details_by_market {
241            if market_packages.warning.is_empty() {
242                table.add_row(vec![
243                    market_packages.market,
244                    market_packages
245                        .packages
246                        .into_iter()
247                        .map(|package| package.name)
248                        .join(", "),
249                ]);
250            } else {
251                table.add_row(vec![market_packages.market, market_packages.warning]);
252            }
253        }
254
255        if config.enable_print_quote_packages {
256            println!("{}", table);
257        }
258
259        tracing::info!(
260            member_id = member_id,
261            quote_level = quote_level,
262            quote_package_details = ?quote_package_details,
263            "quote context initialized",
264        );
265
266        Ok(Self {
267            config,
268            rate_limit,
269            command_rx,
270            push_tx,
271            event_tx,
272            event_rx,
273            http_cli,
274            ws_cli,
275            session: Some(session),
276            close: false,
277            subscriptions: HashMap::new(),
278            trading_days: current_trade_days,
279            store: Store::default(),
280            member_id,
281            quote_level,
282            quote_package_details,
283            push_candlestick_mode,
284        })
285    }
286
287    #[inline]
288    pub(crate) fn member_id(&self) -> i64 {
289        self.member_id
290    }
291
292    #[inline]
293    pub(crate) fn quote_level(&self) -> &str {
294        &self.quote_level
295    }
296
297    #[inline]
298    pub(crate) fn quote_package_details(&self) -> &[QuotePackageDetail] {
299        &self.quote_package_details
300    }
301
302    pub(crate) async fn run(mut self) {
303        while !self.close {
304            match self.main_loop().await {
305                Ok(()) => return,
306                Err(err) => tracing::error!(error = %err, "quote disconnected"),
307            }
308
309            loop {
310                // reconnect
311                tokio::time::sleep(RECONNECT_DELAY).await;
312
313                tracing::info!("connecting to quote server");
314                let (url, res) = self.config.create_quote_ws_request().await;
315                let request = res.expect("BUG: failed to create quote ws request");
316
317                match WsClient::open(
318                    request,
319                    ProtocolVersion::Version1,
320                    CodecType::Protobuf,
321                    Platform::OpenAPI,
322                    self.event_tx.clone(),
323                    self.rate_limit.clone(),
324                )
325                .await
326                {
327                    Ok(ws_cli) => self.ws_cli = ws_cli,
328                    Err(err) => {
329                        tracing::error!(error = %err, "failed to connect quote server");
330                        continue;
331                    }
332                }
333
334                tracing::info!(url = url, "quote server connected");
335
336                // request new session
337                match &self.session {
338                    Some(session) if !session.is_expired() => {
339                        match self
340                            .ws_cli
341                            .request_reconnect(&session.session_id, self.config.create_metadata())
342                            .await
343                        {
344                            Ok(new_session) => self.session = Some(new_session),
345                            Err(err) => {
346                                self.session = None; // invalid session
347                                tracing::error!(error = %err, "failed to request session id");
348                                continue;
349                            }
350                        }
351                    }
352                    _ => {
353                        let otp = match self.http_cli.get_otp_v2().await {
354                            Ok(otp) => otp,
355                            Err(err) => {
356                                tracing::error!(error = %err, "failed to request otp");
357                                continue;
358                            }
359                        };
360
361                        match self
362                            .ws_cli
363                            .request_auth(otp, self.config.create_metadata())
364                            .await
365                        {
366                            Ok(new_session) => self.session = Some(new_session),
367                            Err(err) => {
368                                tracing::error!(error = %err, "failed to request session id");
369                                continue;
370                            }
371                        }
372                    }
373                }
374
375                // handle reconnect
376                match self.resubscribe().await {
377                    Ok(()) => break,
378                    Err(err) => {
379                        tracing::error!(error = %err, "failed to subscribe topics");
380                        continue;
381                    }
382                }
383            }
384        }
385    }
386
387    async fn main_loop(&mut self) -> Result<()> {
388        let mut update_trading_days_interval = tokio::time::interval_at(
389            Instant::now() + Duration::from_secs(60 * 60 * 24),
390            Duration::from_secs(60 * 60 * 24),
391        );
392
393        loop {
394            tokio::select! {
395                item = self.event_rx.recv() => {
396                    match item {
397                        Some(event) => self.handle_ws_event(event).await?,
398                        None => unreachable!(),
399                    }
400                }
401                item = self.command_rx.recv() => {
402                    match item {
403                        Some(command) => self.handle_command(command).await?,
404                        None => {
405                            self.close = true;
406                            return Ok(());
407                        }
408                    }
409                }
410                _ = update_trading_days_interval.tick() => {
411                    if let Ok(days) = fetch_trading_days(&self.ws_cli).await {
412                        self.trading_days = days;
413                    }
414                }
415            }
416        }
417    }
418
419    async fn handle_command(&mut self, command: Command) -> Result<()> {
420        match command {
421            Command::Request {
422                command_code,
423                body,
424                reply_tx,
425            } => self.handle_request(command_code, body, reply_tx).await,
426            Command::Subscribe {
427                symbols,
428                sub_types,
429                is_first_push,
430                reply_tx,
431            } => {
432                let res = self
433                    .handle_subscribe(symbols, sub_types, is_first_push)
434                    .await;
435                let _ = reply_tx.send(res);
436                Ok(())
437            }
438            Command::Unsubscribe {
439                symbols,
440                sub_types,
441                reply_tx,
442            } => {
443                let _ = reply_tx.send(self.handle_unsubscribe(symbols, sub_types).await);
444                Ok(())
445            }
446            Command::SubscribeCandlesticks {
447                symbol,
448                period,
449                trade_sessions,
450                reply_tx,
451            } => {
452                let _ = reply_tx.send(
453                    self.handle_subscribe_candlesticks(symbol, period, trade_sessions)
454                        .await,
455                );
456                Ok(())
457            }
458            Command::UnsubscribeCandlesticks {
459                symbol,
460                period,
461                reply_tx,
462            } => {
463                let _ = reply_tx.send(self.handle_unsubscribe_candlesticks(symbol, period).await);
464                Ok(())
465            }
466            Command::Subscriptions { reply_tx } => {
467                let res = self.handle_subscriptions().await;
468                let _ = reply_tx.send(res);
469                Ok(())
470            }
471            Command::GetRealtimeQuote { symbols, reply_tx } => {
472                let _ = reply_tx.send(self.handle_get_realtime_quote(symbols));
473                Ok(())
474            }
475            Command::GetRealtimeDepth { symbol, reply_tx } => {
476                let _ = reply_tx.send(self.handle_get_realtime_depth(symbol));
477                Ok(())
478            }
479            Command::GetRealtimeTrade {
480                symbol,
481                count,
482                reply_tx,
483            } => {
484                let _ = reply_tx.send(self.handle_get_realtime_trades(symbol, count));
485                Ok(())
486            }
487            Command::GetRealtimeBrokers { symbol, reply_tx } => {
488                let _ = reply_tx.send(self.handle_get_realtime_brokers(symbol));
489                Ok(())
490            }
491            Command::GetRealtimeCandlesticks {
492                symbol,
493                period,
494                count,
495                reply_tx,
496            } => {
497                let _ = reply_tx.send(self.handle_get_realtime_candlesticks(symbol, period, count));
498                Ok(())
499            }
500        }
501    }
502
503    async fn handle_request(
504        &mut self,
505        command_code: u8,
506        body: Vec<u8>,
507        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
508    ) -> Result<()> {
509        let res = self.ws_cli.request_raw(command_code, None, body).await;
510        let _ = reply_tx.send(res.map_err(Into::into));
511        Ok(())
512    }
513
514    async fn handle_subscribe(
515        &mut self,
516        symbols: Vec<String>,
517        sub_types: SubFlags,
518        is_first_push: bool,
519    ) -> Result<()> {
520        // send request
521        let req = SubscribeRequest {
522            symbol: symbols.clone(),
523            sub_type: sub_types.into(),
524            is_first_push,
525        };
526        self.ws_cli
527            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
528            .await?;
529
530        // update subscriptions
531        for symbol in symbols {
532            self.subscriptions
533                .entry(symbol)
534                .and_modify(|flags| *flags |= sub_types)
535                .or_insert(sub_types);
536        }
537
538        Ok(())
539    }
540
541    async fn handle_unsubscribe(
542        &mut self,
543        symbols: Vec<String>,
544        sub_types: SubFlags,
545    ) -> Result<()> {
546        tracing::info!(symbols = ?symbols, sub_types = ?sub_types, "unsubscribe");
547
548        // send requests
549        let mut st_group: HashMap<SubFlags, Vec<&str>> = HashMap::new();
550
551        for symbol in &symbols {
552            let mut st = sub_types;
553
554            if let Some(candlesticks) = self
555                .store
556                .securities
557                .get(symbol)
558                .map(|data| &data.candlesticks)
559            {
560                if !candlesticks.is_empty() {
561                    st.remove(SubFlags::QUOTE | SubFlags::TRADE);
562                }
563            }
564
565            if !st.is_empty() {
566                st_group.entry(st).or_default().push(symbol.as_ref());
567            }
568        }
569
570        let requests = st_group
571            .iter()
572            .map(|(st, symbols)| UnsubscribeRequest {
573                symbol: symbols.iter().map(ToString::to_string).collect(),
574                sub_type: (*st).into(),
575                unsub_all: false,
576            })
577            .collect::<Vec<_>>();
578
579        for req in requests {
580            self.ws_cli
581                .request::<_, ()>(cmd_code::UNSUBSCRIBE, None, req)
582                .await?;
583        }
584
585        // update subscriptions
586        let mut remove_symbols = Vec::new();
587        for symbol in &symbols {
588            if let Some(cur_flags) = self.subscriptions.get_mut(symbol) {
589                *cur_flags &= !sub_types;
590                if cur_flags.is_empty() {
591                    remove_symbols.push(symbol);
592                }
593            }
594        }
595
596        for symbol in remove_symbols {
597            self.subscriptions.remove(symbol);
598        }
599        Ok(())
600    }
601
602    async fn handle_subscribe_candlesticks(
603        &mut self,
604        symbol: String,
605        period: Period,
606        trade_sessions: TradeSessions,
607    ) -> Result<Vec<Candlestick>> {
608        tracing::info!(symbol = symbol, period = ?period, "subscribe candlesticks");
609
610        if let Some(candlesticks) = self
611            .store
612            .securities
613            .get_mut(&symbol)
614            .and_then(|data| data.candlesticks.get_mut(&period))
615            .filter(|candlesticks| candlesticks.trade_sessions == trade_sessions)
616        {
617            candlesticks.trade_sessions = trade_sessions;
618            tracing::info!(symbol = symbol, period = ?period, trade_sessions = ?trade_sessions, "subscribed, returns candlesticks in memory");
619            return Ok(candlesticks.candlesticks.clone());
620        }
621
622        tracing::info!(symbol = symbol, "fetch symbol board");
623
624        let security_data = self.store.securities.entry(symbol.clone()).or_default();
625        if security_data.board != SecurityBoard::Unknown {
626            // update board
627            let resp: SecurityStaticInfoResponse = self
628                .ws_cli
629                .request(
630                    cmd_code::GET_BASIC_INFO,
631                    None,
632                    MultiSecurityRequest {
633                        symbol: vec![symbol.clone()],
634                    },
635                )
636                .await?;
637            if resp.secu_static_info.is_empty() {
638                return Err(Error::InvalidSecuritySymbol {
639                    symbol: symbol.clone(),
640                });
641            }
642            security_data.board = resp.secu_static_info[0].board.parse().unwrap_or_default();
643        }
644
645        tracing::info!(symbol = symbol, board = ?security_data.board, "got the symbol board");
646
647        let Some(market) = parse_market_from_symbol(&symbol)
648            .and_then(|market| get_market(market, security_data.board))
649        else {
650            return Err(Error::UnknownMarket { symbol });
651        };
652
653        // pull candlesticks
654        tracing::info!(symbol = symbol, period = ?period, "pull history candlesticks");
655        let resp: SecurityCandlestickResponse = self
656            .ws_cli
657            .request(
658                cmd_code::GET_SECURITY_CANDLESTICKS,
659                None,
660                SecurityCandlestickRequest {
661                    symbol: symbol.clone(),
662                    period: period.into(),
663                    count: 1000,
664                    adjust_type: AdjustType::NoAdjust.into(),
665                    trade_session: trade_sessions as i32,
666                },
667            )
668            .await?;
669        tracing::info!(symbol = symbol, period = ?period, len = resp.candlesticks.len(), "got history candlesticks");
670
671        let mut candlesticks = vec![];
672        let mut tails = HashMap::new();
673
674        for candlestick in resp.candlesticks {
675            let time = OffsetDateTime::from_unix_timestamp(candlestick.timestamp)
676                .map_err(|err| Error::parse_field_error("timestamp", err))?
677                .to_timezone(market.timezone);
678            let ts = match market.candlestick_trade_session(time) {
679                Some(ts) => ts,
680                None => {
681                    tracing::error!(
682                        symbol = symbol,
683                        time = time
684                            .format(&time::format_description::well_known::Rfc3339)
685                            .unwrap(),
686                        "unknown trade session"
687                    );
688                    return Err(Error::UnknownTradeSession { symbol, time });
689                }
690            };
691            let candlestick = candlestick.try_into()?;
692            let index = candlesticks.len();
693            candlesticks.push(candlestick);
694            tails.insert(ts, TailCandlestick { index, candlestick });
695        }
696
697        tracing::info!(symbol = symbol, period = ?period, count = candlesticks.len(), tails = ?tails, "candlesticks loaded");
698
699        security_data
700            .candlesticks
701            .entry(period)
702            .or_insert_with(|| Candlesticks {
703                trade_sessions,
704                candlesticks: candlesticks.clone(),
705                tails,
706            });
707
708        // subscribe
709        if self
710            .subscriptions
711            .get(&symbol)
712            .copied()
713            .unwrap_or_else(SubFlags::empty)
714            .contains(SubFlags::QUOTE | SubFlags::TRADE)
715        {
716            return Ok(candlesticks);
717        }
718
719        tracing::info!(symbol = symbol, period = ?period, "subscribe quote for candlesticks");
720
721        let req = SubscribeRequest {
722            symbol: vec![symbol.clone()],
723            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
724            is_first_push: true,
725        };
726        self.ws_cli
727            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
728            .await?;
729
730        tracing::info!(symbol = symbol, period = ?period, "subscribed quote for candlesticks");
731        Ok(candlesticks)
732    }
733
734    async fn handle_unsubscribe_candlesticks(
735        &mut self,
736        symbol: String,
737        period: Period,
738    ) -> Result<()> {
739        if let Some(periods) = self
740            .store
741            .securities
742            .get_mut(&symbol)
743            .map(|data| &mut data.candlesticks)
744        {
745            periods.remove(&period);
746
747            let sub_flags = self
748                .subscriptions
749                .get(&symbol)
750                .copied()
751                .unwrap_or_else(SubFlags::empty);
752
753            if periods.is_empty() && !sub_flags.intersects(SubFlags::QUOTE | SubFlags::TRADE) {
754                tracing::info!(symbol = symbol, "unsubscribe quote for candlesticks");
755                self.ws_cli
756                    .request::<_, ()>(
757                        cmd_code::UNSUBSCRIBE,
758                        None,
759                        UnsubscribeRequest {
760                            symbol: vec![symbol],
761                            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
762                            unsub_all: false,
763                        },
764                    )
765                    .await?;
766            }
767        }
768
769        Ok(())
770    }
771
772    async fn handle_subscriptions(&mut self) -> Vec<Subscription> {
773        let mut subscriptions = HashMap::new();
774
775        for (symbol, sub_flags) in &self.subscriptions {
776            if sub_flags.is_empty() {
777                continue;
778            }
779
780            subscriptions.insert(
781                symbol.clone(),
782                Subscription {
783                    symbol: symbol.clone(),
784                    sub_types: *sub_flags,
785                    candlesticks: vec![],
786                },
787            );
788        }
789
790        for (symbol, data) in &self.store.securities {
791            subscriptions
792                .entry(symbol.clone())
793                .or_insert_with(|| Subscription {
794                    symbol: symbol.clone(),
795                    sub_types: SubFlags::empty(),
796                    candlesticks: vec![],
797                })
798                .candlesticks = data.candlesticks.keys().copied().collect();
799        }
800
801        subscriptions.into_values().collect()
802    }
803
804    async fn handle_ws_event(&mut self, event: WsEvent) -> Result<()> {
805        match event {
806            WsEvent::Error(err) => Err(err.into()),
807            WsEvent::Push { command_code, body } => self.handle_push(command_code, body),
808        }
809    }
810
811    async fn resubscribe(&mut self) -> Result<()> {
812        let mut subscriptions: HashMap<SubFlags, HashSet<String>> = HashMap::new();
813
814        for (symbol, flags) in &self.subscriptions {
815            subscriptions
816                .entry(*flags)
817                .or_default()
818                .insert(symbol.clone());
819        }
820
821        for (symbol, data) in &self.store.securities {
822            if !data.candlesticks.is_empty() {
823                subscriptions
824                    .entry(SubFlags::QUOTE | SubFlags::TRADE)
825                    .or_default()
826                    .insert(symbol.clone());
827            }
828        }
829
830        tracing::info!(subscriptions = ?subscriptions, "resubscribe");
831
832        for (flags, symbols) in subscriptions {
833            self.ws_cli
834                .request::<_, ()>(
835                    cmd_code::SUBSCRIBE,
836                    None,
837                    SubscribeRequest {
838                        symbol: symbols.into_iter().collect(),
839                        sub_type: flags.into(),
840                        is_first_push: false,
841                    },
842                )
843                .await?;
844        }
845        Ok(())
846    }
847
848    fn merge_candlesticks_by_quote(&mut self, symbol: &str, push_quote: &PushQuote) {
849        if push_quote.trade_session != TradeSession::Intraday {
850            return;
851        }
852
853        let Some(market_type) = parse_market_from_symbol(symbol) else {
854            return;
855        };
856        let Some(security_data) = self.store.securities.get_mut(symbol) else {
857            return;
858        };
859        let half_days = self.trading_days.half_days(market_type);
860
861        if let Some(candlesticks) = security_data.candlesticks.get_mut(&Period::Day) {
862            let ts = convert_trade_session(push_quote.trade_session);
863            let action = candlesticks.merge_quote(
864                ts,
865                market_type,
866                half_days,
867                security_data.board,
868                Period::Day,
869                push_quote,
870            );
871            update_and_push_candlestick(
872                candlesticks,
873                ts,
874                push_quote.trade_session,
875                symbol,
876                Period::Day,
877                action,
878                self.push_candlestick_mode,
879                &mut self.push_tx,
880            );
881        }
882    }
883
884    fn merge_candlesticks_by_trades(&mut self, symbol: &str, push_trades: &PushTrades) {
885        let Some(market_type) = parse_market_from_symbol(symbol) else {
886            return;
887        };
888        let Some(security_data) = self.store.securities.get_mut(symbol) else {
889            return;
890        };
891        let half_days = self.trading_days.half_days(market_type);
892
893        for trade in &push_trades.trades {
894            let ts = convert_trade_session(trade.trade_session);
895
896            for (period, candlesticks) in &mut security_data.candlesticks {
897                if *period >= Period::Day && !ts.is_intraday() {
898                    continue;
899                }
900
901                let action = candlesticks.merge_trade(
902                    ts,
903                    market_type,
904                    half_days,
905                    security_data.board,
906                    *period,
907                    trade,
908                );
909                update_and_push_candlestick(
910                    candlesticks,
911                    ts,
912                    trade.trade_session,
913                    symbol,
914                    *period,
915                    action,
916                    self.push_candlestick_mode,
917                    &mut self.push_tx,
918                );
919            }
920        }
921    }
922
923    fn handle_push(&mut self, command_code: u8, body: Vec<u8>) -> Result<()> {
924        match PushEvent::parse(command_code, &body) {
925            Ok((mut event, tag)) => {
926                tracing::info!(event = ?event, tag = ?tag, "push event");
927
928                if tag != Some(PushQuoteTag::Eod) {
929                    self.store.handle_push(&mut event);
930                }
931
932                if let PushEventDetail::Quote(push_quote) = &event.detail {
933                    self.merge_candlesticks_by_quote(&event.symbol, push_quote);
934
935                    if !self
936                        .subscriptions
937                        .get(&event.symbol)
938                        .map(|sub_flags| sub_flags.contains(SubFlags::QUOTE))
939                        .unwrap_or_default()
940                    {
941                        return Ok(());
942                    }
943                } else if let PushEventDetail::Trade(trades) = &event.detail {
944                    self.merge_candlesticks_by_trades(&event.symbol, trades);
945
946                    if !self
947                        .subscriptions
948                        .get(&event.symbol)
949                        .map(|sub_flags| sub_flags.contains(SubFlags::TRADE))
950                        .unwrap_or_default()
951                    {
952                        return Ok(());
953                    }
954                }
955
956                if tag == Some(PushQuoteTag::Eod) {
957                    return Ok(());
958                }
959
960                let _ = self.push_tx.send(event);
961            }
962            Err(err) => {
963                tracing::error!(error = %err, "failed to parse push message");
964            }
965        }
966        Ok(())
967    }
968
969    fn handle_get_realtime_quote(&self, symbols: Vec<String>) -> Vec<RealtimeQuote> {
970        let mut result = Vec::new();
971
972        for symbol in symbols {
973            if let Some(data) = self.store.securities.get(&symbol) {
974                result.push(RealtimeQuote {
975                    symbol,
976                    last_done: data.quote.last_done,
977                    open: data.quote.open,
978                    high: data.quote.high,
979                    low: data.quote.low,
980                    timestamp: data.quote.timestamp,
981                    volume: data.quote.volume,
982                    turnover: data.quote.turnover,
983                    trade_status: data.quote.trade_status,
984                });
985            }
986        }
987
988        result
989    }
990
991    fn handle_get_realtime_depth(&self, symbol: String) -> SecurityDepth {
992        let mut result = SecurityDepth::default();
993        if let Some(data) = self.store.securities.get(&symbol) {
994            result.asks.clone_from(&data.asks);
995            result.bids.clone_from(&data.bids);
996        }
997        result
998    }
999
1000    fn handle_get_realtime_trades(&self, symbol: String, count: usize) -> Vec<Trade> {
1001        let mut res = Vec::new();
1002
1003        if let Some(data) = self.store.securities.get(&symbol) {
1004            let trades = if data.trades.len() >= count {
1005                &data.trades[data.trades.len() - count..]
1006            } else {
1007                &data.trades
1008            };
1009            res = trades.to_vec();
1010        }
1011        res
1012    }
1013
1014    fn handle_get_realtime_brokers(&self, symbol: String) -> SecurityBrokers {
1015        let mut result = SecurityBrokers::default();
1016        if let Some(data) = self.store.securities.get(&symbol) {
1017            result.ask_brokers.clone_from(&data.ask_brokers);
1018            result.bid_brokers.clone_from(&data.bid_brokers);
1019        }
1020        result
1021    }
1022
1023    fn handle_get_realtime_candlesticks(
1024        &self,
1025        symbol: String,
1026        period: Period,
1027        count: usize,
1028    ) -> Vec<Candlestick> {
1029        self.store
1030            .securities
1031            .get(&symbol)
1032            .map(|data| &data.candlesticks)
1033            .and_then(|periods| periods.get(&period))
1034            .map(|candlesticks| {
1035                let candlesticks = if candlesticks.candlesticks.len() >= count {
1036                    &candlesticks.candlesticks[candlesticks.candlesticks.len() - count..]
1037                } else {
1038                    &candlesticks.candlesticks
1039                };
1040                candlesticks.to_vec()
1041            })
1042            .unwrap_or_default()
1043    }
1044}
1045
1046async fn fetch_trading_days(cli: &WsClient) -> Result<TradingDays> {
1047    let mut days = TradingDays::default();
1048    let begin_day = OffsetDateTime::now_utc().date() - time::Duration::days(5);
1049    let end_day = begin_day + time::Duration::days(30);
1050
1051    for market in [Market::HK, Market::US, Market::SG, Market::CN] {
1052        let resp = cli
1053            .request::<_, MarketTradeDayResponse>(
1054                cmd_code::GET_TRADING_DAYS,
1055                None,
1056                MarketTradeDayRequest {
1057                    market: market.to_string(),
1058                    beg_day: format_date(begin_day),
1059                    end_day: format_date(end_day),
1060                },
1061            )
1062            .await?;
1063
1064        days.normal_days.insert(
1065            market,
1066            resp.trade_day
1067                .iter()
1068                .map(|value| {
1069                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1070                })
1071                .collect::<Result<HashSet<_>>>()?,
1072        );
1073
1074        days.half_days.insert(
1075            market,
1076            resp.half_trade_day
1077                .iter()
1078                .map(|value| {
1079                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1080                })
1081                .collect::<Result<HashSet<_>>>()?,
1082        );
1083    }
1084
1085    Ok(days)
1086}
1087
1088#[allow(clippy::too_many_arguments)]
1089fn update_and_push_candlestick(
1090    candlesticks: &mut Candlesticks,
1091    ts: TradeSessionType,
1092    ts1: TradeSession,
1093    symbol: &str,
1094    period: Period,
1095    action: UpdateAction,
1096    push_candlestick_mode: PushCandlestickMode,
1097    tx: &mut mpsc::UnboundedSender<PushEvent>,
1098) {
1099    let mut push_candlesticks = Vec::new();
1100
1101    match action {
1102        UpdateAction::UpdateLast(candlestick) => {
1103            let tail = candlesticks.tails.get_mut(&ts).unwrap();
1104            candlesticks.candlesticks[tail.index] = (candlestick, ts1).into();
1105            tail.candlestick = (candlestick, ts1).into();
1106
1107            if push_candlestick_mode == PushCandlestickMode::Realtime {
1108                push_candlesticks.push(((candlestick, ts1).into(), false));
1109            }
1110        }
1111        UpdateAction::AppendNew { confirmed, new } => {
1112            let index = if let Some(tail) = candlesticks.tails.get_mut(&ts) {
1113                candlesticks
1114                    .candlesticks
1115                    .insert(tail.index + 1, (new, ts1).into());
1116                tail.index += 1;
1117                tail.candlestick = (new, ts1).into();
1118                tail.index
1119            } else {
1120                let index = candlesticks.insert_candlestick_by_time((new, ts1).into());
1121                candlesticks.tails.insert(
1122                    ts,
1123                    TailCandlestick {
1124                        index,
1125                        candlestick: (new, ts1).into(),
1126                    },
1127                );
1128                index
1129            };
1130
1131            for tail in candlesticks.tails.values_mut() {
1132                if tail.index > index {
1133                    tail.index += 1;
1134                }
1135            }
1136
1137            candlesticks.check_and_remove();
1138
1139            match push_candlestick_mode {
1140                PushCandlestickMode::Realtime => {
1141                    if let Some(confirmed) = confirmed {
1142                        push_candlesticks.push(((confirmed, ts1).into(), true));
1143                    }
1144                    push_candlesticks.push(((new, ts1).into(), false));
1145                }
1146                PushCandlestickMode::Confirmed => {
1147                    if let Some(confirmed) = confirmed {
1148                        push_candlesticks.push(((confirmed, ts1).into(), true));
1149                    }
1150                }
1151            }
1152        }
1153        UpdateAction::None => {}
1154    };
1155
1156    for (candlestick, is_confirmed) in push_candlesticks {
1157        if candlesticks.trade_sessions.contains(ts1) {
1158            tracing::info!(
1159                symbol = symbol,
1160                period = ?period,
1161                is_confirmed = is_confirmed,
1162                candlestick = ?candlestick,
1163                trade_session = ?ts,
1164                "push candlestick"
1165            );
1166            let _ = tx.send(PushEvent {
1167                sequence: 0,
1168                symbol: symbol.to_string(),
1169                detail: PushEventDetail::Candlestick(PushCandlestick {
1170                    period,
1171                    candlestick,
1172                    is_confirmed,
1173                }),
1174            });
1175        }
1176    }
1177}
1178
1179fn parse_market_from_symbol(symbol: &str) -> Option<Market> {
1180    let market = symbol.rfind('.').map(|idx| &symbol[idx + 1..])?;
1181    Some(match market {
1182        "US" => Market::US,
1183        "HK" => Market::HK,
1184        "SG" => Market::SG,
1185        "SH" | "SZ" => Market::CN,
1186        _ => return None,
1187    })
1188}
1189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_market_from_symbol() {
        assert_eq!(parse_market_from_symbol("AAPL.US"), Some(Market::US));
        // Only the suffix after the last dot selects the market.
        assert_eq!(parse_market_from_symbol("BRK.A.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("700.HK"), Some(Market::HK));
        assert_eq!(parse_market_from_symbol("D05.SG"), Some(Market::SG));
        // Both mainland exchanges map to the CN market.
        assert_eq!(parse_market_from_symbol("600519.SH"), Some(Market::CN));
        assert_eq!(parse_market_from_symbol("000001.SZ"), Some(Market::CN));
    }

    #[test]
    fn test_parse_market_from_symbol_rejects_unknown() {
        // No dot at all.
        assert_eq!(parse_market_from_symbol("AAPL"), None);
        // Empty suffix.
        assert_eq!(parse_market_from_symbol("AAPL."), None);
        // Unknown suffix.
        assert_eq!(parse_market_from_symbol("AAPL.XX"), None);
        // Suffix matching is case-sensitive.
        assert_eq!(parse_market_from_symbol("AAPL.us"), None);
    }
}