longport/quote/
core.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use comfy_table::Table;
7use itertools::Itertools;
8use longport_candlesticks::{TradeSessionType, UpdateAction};
9use longport_httpcli::HttpClient;
10use longport_proto::quote::{
11    self, AdjustType, MarketTradeDayRequest, MarketTradeDayResponse, MultiSecurityRequest, Period,
12    PushQuoteTag, SecurityCandlestickRequest, SecurityCandlestickResponse,
13    SecurityStaticInfoResponse, SubscribeRequest, UnsubscribeRequest,
14};
15use longport_wscli::{
16    CodecType, Platform, ProtocolVersion, RateLimit, WsClient, WsClientError, WsEvent, WsSession,
17};
18use time::{Date, OffsetDateTime};
19use tokio::{
20    sync::{mpsc, oneshot},
21    time::{Duration, Instant},
22};
23
24use crate::{
25    Config, Error, Market, Result,
26    config::PushCandlestickMode,
27    quote::{
28        Candlestick, PushCandlestick, PushEvent, PushEventDetail, PushQuote, PushTrades,
29        RealtimeQuote, SecurityBoard, SecurityBrokers, SecurityDepth, Subscription, Trade,
30        TradeSession, TradeSessions, cmd_code,
31        store::{Candlesticks, Store, TailCandlestick},
32        sub_flags::SubFlags,
33        types::QuotePackageDetail,
34        utils::{format_date, parse_date},
35    },
36};
37
/// Pause inserted before every reconnect attempt made by `Core::run`.
const RECONNECT_DELAY: Duration = Duration::from_millis(2_000);
39
40pub(crate) enum Command {
41    Request {
42        command_code: u8,
43        body: Vec<u8>,
44        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
45    },
46    Subscribe {
47        symbols: Vec<String>,
48        sub_types: SubFlags,
49        reply_tx: oneshot::Sender<Result<()>>,
50    },
51    Unsubscribe {
52        symbols: Vec<String>,
53        sub_types: SubFlags,
54        reply_tx: oneshot::Sender<Result<()>>,
55    },
56    SubscribeCandlesticks {
57        symbol: String,
58        period: Period,
59        trade_sessions: TradeSessions,
60        reply_tx: oneshot::Sender<Result<Vec<Candlestick>>>,
61    },
62    UnsubscribeCandlesticks {
63        symbol: String,
64        period: Period,
65        reply_tx: oneshot::Sender<Result<()>>,
66    },
67    Subscriptions {
68        reply_tx: oneshot::Sender<Vec<Subscription>>,
69    },
70    GetRealtimeQuote {
71        symbols: Vec<String>,
72        reply_tx: oneshot::Sender<Vec<RealtimeQuote>>,
73    },
74    GetRealtimeDepth {
75        symbol: String,
76        reply_tx: oneshot::Sender<SecurityDepth>,
77    },
78    GetRealtimeTrade {
79        symbol: String,
80        count: usize,
81        reply_tx: oneshot::Sender<Vec<Trade>>,
82    },
83    GetRealtimeBrokers {
84        symbol: String,
85        reply_tx: oneshot::Sender<SecurityBrokers>,
86    },
87    GetRealtimeCandlesticks {
88        symbol: String,
89        period: Period,
90        count: usize,
91        reply_tx: oneshot::Sender<Vec<Candlestick>>,
92    },
93}
94
95#[derive(Debug, Default)]
96struct TradingDays {
97    normal_days: HashMap<Market, HashSet<Date>>,
98    half_days: HashMap<Market, HashSet<Date>>,
99}
100
101impl TradingDays {
102    #[inline]
103    fn half_days(&self, market: Market) -> Days<'_> {
104        Days(self.half_days.get(&market))
105    }
106}
107
108#[derive(Debug, Copy, Clone)]
109struct Days<'a>(Option<&'a HashSet<Date>>);
110
111impl longport_candlesticks::Days for Days<'_> {
112    #[inline]
113    fn contains(&self, date: Date) -> bool {
114        match self.0 {
115            Some(days) => days.contains(&date),
116            None => false,
117        }
118    }
119}
120
121#[derive(Debug)]
122pub(crate) struct MarketPackageDetail {
123    pub(crate) market: String,
124    pub(crate) packages: Vec<QuotePackageDetail>,
125    pub(crate) warning: String,
126}
127
128pub(crate) struct Core {
129    config: Arc<Config>,
130    rate_limit: Vec<(u8, RateLimit)>,
131    command_rx: mpsc::UnboundedReceiver<Command>,
132    push_tx: mpsc::UnboundedSender<PushEvent>,
133    event_tx: mpsc::UnboundedSender<WsEvent>,
134    event_rx: mpsc::UnboundedReceiver<WsEvent>,
135    http_cli: HttpClient,
136    ws_cli: WsClient,
137    session: Option<WsSession>,
138    close: bool,
139    subscriptions: HashMap<String, SubFlags>,
140    trading_days: TradingDays,
141    store: Store,
142    member_id: i64,
143    quote_level: String,
144    quote_package_details: Vec<QuotePackageDetail>,
145    push_candlestick_mode: PushCandlestickMode,
146}
147
148impl Core {
149    pub(crate) async fn try_new(
150        config: Arc<Config>,
151        command_rx: mpsc::UnboundedReceiver<Command>,
152        push_tx: mpsc::UnboundedSender<PushEvent>,
153    ) -> Result<Self> {
154        let http_cli = config.create_http_client();
155        let otp = http_cli.get_otp().await?;
156
157        let (event_tx, event_rx) = mpsc::unbounded_channel();
158
159        tracing::info!("connecting to quote server");
160        let (url, res) = config.create_quote_ws_request().await;
161        let request = res.map_err(WsClientError::from)?;
162
163        let mut ws_cli = WsClient::open(
164            request,
165            ProtocolVersion::Version1,
166            CodecType::Protobuf,
167            Platform::OpenAPI,
168            event_tx.clone(),
169            vec![],
170        )
171        .await?;
172
173        tracing::info!(url = url, "quote server connected");
174
175        let session = ws_cli.request_auth(otp, config.create_metadata()).await?;
176
177        // fetch user profile
178        let resp = ws_cli
179            .request::<_, quote::UserQuoteProfileResponse>(
180                cmd_code::QUERY_USER_QUOTE_PROFILE,
181                None,
182                quote::UserQuoteProfileRequest {
183                    language: config.language.to_string(),
184                },
185            )
186            .await?;
187        let member_id = resp.member_id;
188        let quote_level = resp.quote_level;
189        let (quote_package_details, quote_package_details_by_market) = resp
190            .quote_level_detail
191            .map(|details| {
192                Ok::<_, Error>((
193                    details
194                        .by_package_key
195                        .into_values()
196                        .map(TryInto::try_into)
197                        .collect::<Result<Vec<_>>>()?,
198                    details
199                        .by_market_code
200                        .into_iter()
201                        .map(|(market, market_packages)| {
202                            Ok(MarketPackageDetail {
203                                market,
204                                packages: market_packages
205                                    .packages
206                                    .into_iter()
207                                    .map(TryInto::try_into)
208                                    .collect::<Result<Vec<_>>>()?,
209                                warning: market_packages.warning_msg,
210                            })
211                        })
212                        .collect::<Result<Vec<_>>>()?,
213                ))
214            })
215            .transpose()?
216            .unwrap_or_default();
217        let rate_limit: Vec<(u8, RateLimit)> = resp
218            .rate_limit
219            .iter()
220            .map(|config| {
221                (
222                    config.command as u8,
223                    RateLimit {
224                        interval: Duration::from_secs(1),
225                        initial: config.burst as usize,
226                        max: config.burst as usize,
227                        refill: config.limit as usize,
228                    },
229                )
230            })
231            .collect();
232        ws_cli.set_rate_limit(rate_limit.clone());
233
234        let current_trade_days = fetch_trading_days(&ws_cli).await?;
235        let push_candlestick_mode = config.push_candlestick_mode.unwrap_or_default();
236
237        let mut table = Table::new();
238        for market_packages in quote_package_details_by_market {
239            if market_packages.warning.is_empty() {
240                table.add_row(vec![
241                    market_packages.market,
242                    market_packages
243                        .packages
244                        .into_iter()
245                        .map(|package| package.name)
246                        .join(", "),
247                ]);
248            } else {
249                table.add_row(vec![market_packages.market, market_packages.warning]);
250            }
251        }
252
253        if config.enable_print_quote_packages {
254            println!("{table}");
255        }
256
257        tracing::info!(
258            member_id = member_id,
259            quote_level = quote_level,
260            quote_package_details = ?quote_package_details,
261            "quote context initialized",
262        );
263
264        Ok(Self {
265            config,
266            rate_limit,
267            command_rx,
268            push_tx,
269            event_tx,
270            event_rx,
271            http_cli,
272            ws_cli,
273            session: Some(session),
274            close: false,
275            subscriptions: HashMap::new(),
276            trading_days: current_trade_days,
277            store: Store::default(),
278            member_id,
279            quote_level,
280            quote_package_details,
281            push_candlestick_mode,
282        })
283    }
284
285    #[inline]
286    pub(crate) fn member_id(&self) -> i64 {
287        self.member_id
288    }
289
290    #[inline]
291    pub(crate) fn quote_level(&self) -> &str {
292        &self.quote_level
293    }
294
295    #[inline]
296    pub(crate) fn quote_package_details(&self) -> &[QuotePackageDetail] {
297        &self.quote_package_details
298    }
299
300    pub(crate) async fn run(mut self) {
301        while !self.close {
302            match self.main_loop().await {
303                Ok(()) => return,
304                Err(err) => tracing::error!(error = %err, "quote disconnected"),
305            }
306
307            loop {
308                // reconnect
309                tokio::time::sleep(RECONNECT_DELAY).await;
310
311                tracing::info!("connecting to quote server");
312                let (url, res) = self.config.create_quote_ws_request().await;
313                let request = res.expect("BUG: failed to create quote ws request");
314
315                match WsClient::open(
316                    request,
317                    ProtocolVersion::Version1,
318                    CodecType::Protobuf,
319                    Platform::OpenAPI,
320                    self.event_tx.clone(),
321                    self.rate_limit.clone(),
322                )
323                .await
324                {
325                    Ok(ws_cli) => self.ws_cli = ws_cli,
326                    Err(err) => {
327                        tracing::error!(error = %err, "failed to connect quote server");
328                        continue;
329                    }
330                }
331
332                tracing::info!(url = url, "quote server connected");
333
334                // request new session
335                match &self.session {
336                    Some(session) if !session.is_expired() => {
337                        match self
338                            .ws_cli
339                            .request_reconnect(&session.session_id, self.config.create_metadata())
340                            .await
341                        {
342                            Ok(new_session) => self.session = Some(new_session),
343                            Err(err) => {
344                                self.session = None; // invalid session
345                                tracing::error!(error = %err, "failed to request session id");
346                                continue;
347                            }
348                        }
349                    }
350                    _ => {
351                        let otp = match self.http_cli.get_otp().await {
352                            Ok(otp) => otp,
353                            Err(err) => {
354                                tracing::error!(error = %err, "failed to request otp");
355                                continue;
356                            }
357                        };
358
359                        match self
360                            .ws_cli
361                            .request_auth(otp, self.config.create_metadata())
362                            .await
363                        {
364                            Ok(new_session) => self.session = Some(new_session),
365                            Err(err) => {
366                                tracing::error!(error = %err, "failed to request session id");
367                                continue;
368                            }
369                        }
370                    }
371                }
372
373                // handle reconnect
374                match self.resubscribe().await {
375                    Ok(()) => break,
376                    Err(err) => {
377                        tracing::error!(error = %err, "failed to subscribe topics");
378                        continue;
379                    }
380                }
381            }
382        }
383    }
384
385    async fn main_loop(&mut self) -> Result<()> {
386        let mut update_trading_days_interval = tokio::time::interval_at(
387            Instant::now() + Duration::from_secs(60 * 60 * 24),
388            Duration::from_secs(60 * 60 * 24),
389        );
390
391        loop {
392            tokio::select! {
393                item = self.event_rx.recv() => {
394                    match item {
395                        Some(event) => self.handle_ws_event(event).await?,
396                        None => unreachable!(),
397                    }
398                }
399                item = self.command_rx.recv() => {
400                    match item {
401                        Some(command) => self.handle_command(command).await?,
402                        None => {
403                            self.close = true;
404                            return Ok(());
405                        }
406                    }
407                }
408                _ = update_trading_days_interval.tick() => {
409                    if let Ok(days) = fetch_trading_days(&self.ws_cli).await {
410                        self.trading_days = days;
411                    }
412                }
413            }
414        }
415    }
416
417    async fn handle_command(&mut self, command: Command) -> Result<()> {
418        match command {
419            Command::Request {
420                command_code,
421                body,
422                reply_tx,
423            } => self.handle_request(command_code, body, reply_tx).await,
424            Command::Subscribe {
425                symbols,
426                sub_types,
427                reply_tx,
428            } => {
429                let res = self.handle_subscribe(symbols, sub_types).await;
430                let _ = reply_tx.send(res);
431                Ok(())
432            }
433            Command::Unsubscribe {
434                symbols,
435                sub_types,
436                reply_tx,
437            } => {
438                let _ = reply_tx.send(self.handle_unsubscribe(symbols, sub_types).await);
439                Ok(())
440            }
441            Command::SubscribeCandlesticks {
442                symbol,
443                period,
444                trade_sessions,
445                reply_tx,
446            } => {
447                let _ = reply_tx.send(
448                    self.handle_subscribe_candlesticks(symbol, period, trade_sessions)
449                        .await,
450                );
451                Ok(())
452            }
453            Command::UnsubscribeCandlesticks {
454                symbol,
455                period,
456                reply_tx,
457            } => {
458                let _ = reply_tx.send(self.handle_unsubscribe_candlesticks(symbol, period).await);
459                Ok(())
460            }
461            Command::Subscriptions { reply_tx } => {
462                let res = self.handle_subscriptions().await;
463                let _ = reply_tx.send(res);
464                Ok(())
465            }
466            Command::GetRealtimeQuote { symbols, reply_tx } => {
467                let _ = reply_tx.send(self.handle_get_realtime_quote(symbols));
468                Ok(())
469            }
470            Command::GetRealtimeDepth { symbol, reply_tx } => {
471                let _ = reply_tx.send(self.handle_get_realtime_depth(symbol));
472                Ok(())
473            }
474            Command::GetRealtimeTrade {
475                symbol,
476                count,
477                reply_tx,
478            } => {
479                let _ = reply_tx.send(self.handle_get_realtime_trades(symbol, count));
480                Ok(())
481            }
482            Command::GetRealtimeBrokers { symbol, reply_tx } => {
483                let _ = reply_tx.send(self.handle_get_realtime_brokers(symbol));
484                Ok(())
485            }
486            Command::GetRealtimeCandlesticks {
487                symbol,
488                period,
489                count,
490                reply_tx,
491            } => {
492                let _ = reply_tx.send(self.handle_get_realtime_candlesticks(symbol, period, count));
493                Ok(())
494            }
495        }
496    }
497
498    async fn handle_request(
499        &mut self,
500        command_code: u8,
501        body: Vec<u8>,
502        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
503    ) -> Result<()> {
504        let res = self.ws_cli.request_raw(command_code, None, body).await;
505        let _ = reply_tx.send(res.map_err(Into::into));
506        Ok(())
507    }
508
509    async fn handle_subscribe(&mut self, symbols: Vec<String>, sub_types: SubFlags) -> Result<()> {
510        // send request
511        let req = SubscribeRequest {
512            symbol: symbols.clone(),
513            sub_type: sub_types.into(),
514            is_first_push: true,
515        };
516        self.ws_cli
517            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
518            .await?;
519
520        // update subscriptions
521        for symbol in symbols {
522            self.subscriptions
523                .entry(symbol)
524                .and_modify(|flags| *flags |= sub_types)
525                .or_insert(sub_types);
526        }
527
528        Ok(())
529    }
530
531    async fn handle_unsubscribe(
532        &mut self,
533        symbols: Vec<String>,
534        sub_types: SubFlags,
535    ) -> Result<()> {
536        tracing::info!(symbols = ?symbols, sub_types = ?sub_types, "unsubscribe");
537
538        // send requests
539        let mut st_group: HashMap<SubFlags, Vec<&str>> = HashMap::new();
540
541        for symbol in &symbols {
542            let mut st = sub_types;
543
544            if let Some(candlesticks) = self
545                .store
546                .securities
547                .get(symbol)
548                .map(|data| &data.candlesticks)
549                && !candlesticks.is_empty()
550            {
551                st.remove(SubFlags::QUOTE | SubFlags::TRADE);
552            }
553
554            if !st.is_empty() {
555                st_group.entry(st).or_default().push(symbol.as_ref());
556            }
557        }
558
559        let requests = st_group
560            .iter()
561            .map(|(st, symbols)| UnsubscribeRequest {
562                symbol: symbols.iter().map(ToString::to_string).collect(),
563                sub_type: (*st).into(),
564                unsub_all: false,
565            })
566            .collect::<Vec<_>>();
567
568        for req in requests {
569            self.ws_cli
570                .request::<_, ()>(cmd_code::UNSUBSCRIBE, None, req)
571                .await?;
572        }
573
574        // update subscriptions
575        let mut remove_symbols = Vec::new();
576        for symbol in &symbols {
577            if let Some(cur_flags) = self.subscriptions.get_mut(symbol) {
578                *cur_flags &= !sub_types;
579                if cur_flags.is_empty() {
580                    remove_symbols.push(symbol);
581                }
582            }
583        }
584
585        for symbol in remove_symbols {
586            self.subscriptions.remove(symbol);
587        }
588        Ok(())
589    }
590
591    async fn handle_subscribe_candlesticks(
592        &mut self,
593        symbol: String,
594        period: Period,
595        trade_sessions: TradeSessions,
596    ) -> Result<Vec<Candlestick>> {
597        tracing::info!(symbol = symbol, period = ?period, "subscribe candlesticks");
598
599        if let Some(candlesticks) = self
600            .store
601            .securities
602            .get_mut(&symbol)
603            .and_then(|data| data.candlesticks.get_mut(&period))
604            .filter(|candlesticks| candlesticks.trade_sessions == trade_sessions)
605        {
606            candlesticks.trade_sessions = trade_sessions;
607            tracing::info!(symbol = symbol, period = ?period, trade_sessions = ?trade_sessions, "subscribed, returns candlesticks in memory");
608            return Ok(candlesticks.candlesticks.clone());
609        }
610
611        tracing::info!(symbol = symbol, "fetch symbol board");
612
613        let security_data = self.store.securities.entry(symbol.clone()).or_default();
614        if security_data.board == SecurityBoard::Unknown {
615            // update board
616            let resp: SecurityStaticInfoResponse = self
617                .ws_cli
618                .request(
619                    cmd_code::GET_BASIC_INFO,
620                    None,
621                    MultiSecurityRequest {
622                        symbol: vec![symbol.clone()],
623                    },
624                )
625                .await?;
626            if resp.secu_static_info.is_empty() {
627                return Err(Error::InvalidSecuritySymbol {
628                    symbol: symbol.clone(),
629                });
630            }
631            security_data.board = resp.secu_static_info[0].board.parse().unwrap_or_default();
632        }
633
634        tracing::info!(symbol = symbol, board = ?security_data.board, "got the symbol board");
635
636        // pull candlesticks
637        tracing::info!(symbol = symbol, period = ?period, "pull history candlesticks");
638        let resp: SecurityCandlestickResponse = self
639            .ws_cli
640            .request(
641                cmd_code::GET_SECURITY_CANDLESTICKS,
642                None,
643                SecurityCandlestickRequest {
644                    symbol: symbol.clone(),
645                    period: period.into(),
646                    count: 1000,
647                    adjust_type: AdjustType::NoAdjust.into(),
648                    trade_session: trade_sessions as i32,
649                },
650            )
651            .await?;
652        tracing::info!(symbol = symbol, period = ?period, len = resp.candlesticks.len(), "got history candlesticks");
653
654        let mut candlesticks = vec![];
655        let mut tails = HashMap::new();
656
657        for candlestick in resp.candlesticks {
658            let candlestick: Candlestick = candlestick.try_into()?;
659            let index = candlesticks.len();
660            candlesticks.push(candlestick);
661            tails.insert(
662                candlestick.trade_session,
663                TailCandlestick { index, candlestick },
664            );
665        }
666
667        tracing::info!(symbol = symbol, period = ?period, count = candlesticks.len(), tails = ?tails, "candlesticks loaded");
668
669        security_data
670            .candlesticks
671            .entry(period)
672            .or_insert_with(|| Candlesticks {
673                trade_sessions,
674                candlesticks: candlesticks.clone(),
675                tails,
676            });
677
678        // subscribe
679        if self
680            .subscriptions
681            .get(&symbol)
682            .copied()
683            .unwrap_or_else(SubFlags::empty)
684            .contains(SubFlags::QUOTE | SubFlags::TRADE)
685        {
686            return Ok(candlesticks);
687        }
688
689        tracing::info!(symbol = symbol, period = ?period, "subscribe quote for candlesticks");
690
691        let req = SubscribeRequest {
692            symbol: vec![symbol.clone()],
693            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
694            is_first_push: true,
695        };
696        self.ws_cli
697            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
698            .await?;
699
700        tracing::info!(symbol = symbol, period = ?period, "subscribed quote for candlesticks");
701        Ok(candlesticks)
702    }
703
704    async fn handle_unsubscribe_candlesticks(
705        &mut self,
706        symbol: String,
707        period: Period,
708    ) -> Result<()> {
709        if let Some(periods) = self
710            .store
711            .securities
712            .get_mut(&symbol)
713            .map(|data| &mut data.candlesticks)
714        {
715            periods.remove(&period);
716
717            let sub_flags = self
718                .subscriptions
719                .get(&symbol)
720                .copied()
721                .unwrap_or_else(SubFlags::empty);
722
723            if periods.is_empty() && !sub_flags.intersects(SubFlags::QUOTE | SubFlags::TRADE) {
724                tracing::info!(symbol = symbol, "unsubscribe quote for candlesticks");
725                self.ws_cli
726                    .request::<_, ()>(
727                        cmd_code::UNSUBSCRIBE,
728                        None,
729                        UnsubscribeRequest {
730                            symbol: vec![symbol],
731                            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
732                            unsub_all: false,
733                        },
734                    )
735                    .await?;
736            }
737        }
738
739        Ok(())
740    }
741
742    async fn handle_subscriptions(&mut self) -> Vec<Subscription> {
743        let mut subscriptions = HashMap::new();
744
745        for (symbol, sub_flags) in &self.subscriptions {
746            if sub_flags.is_empty() {
747                continue;
748            }
749
750            subscriptions.insert(
751                symbol.clone(),
752                Subscription {
753                    symbol: symbol.clone(),
754                    sub_types: *sub_flags,
755                    candlesticks: vec![],
756                },
757            );
758        }
759
760        for (symbol, data) in &self.store.securities {
761            subscriptions
762                .entry(symbol.clone())
763                .or_insert_with(|| Subscription {
764                    symbol: symbol.clone(),
765                    sub_types: SubFlags::empty(),
766                    candlesticks: vec![],
767                })
768                .candlesticks = data.candlesticks.keys().copied().collect();
769        }
770
771        subscriptions.into_values().collect()
772    }
773
774    async fn handle_ws_event(&mut self, event: WsEvent) -> Result<()> {
775        match event {
776            WsEvent::Error(err) => Err(err.into()),
777            WsEvent::Push { command_code, body } => self.handle_push(command_code, body),
778        }
779    }
780
781    async fn resubscribe(&mut self) -> Result<()> {
782        let mut subscriptions: HashMap<SubFlags, HashSet<String>> = HashMap::new();
783
784        for (symbol, flags) in &self.subscriptions {
785            subscriptions
786                .entry(*flags)
787                .or_default()
788                .insert(symbol.clone());
789        }
790
791        for (symbol, data) in &self.store.securities {
792            if !data.candlesticks.is_empty() {
793                subscriptions
794                    .entry(SubFlags::QUOTE | SubFlags::TRADE)
795                    .or_default()
796                    .insert(symbol.clone());
797            }
798        }
799
800        tracing::info!(subscriptions = ?subscriptions, "resubscribe");
801
802        for (flags, symbols) in subscriptions {
803            self.ws_cli
804                .request::<_, ()>(
805                    cmd_code::SUBSCRIBE,
806                    None,
807                    SubscribeRequest {
808                        symbol: symbols.into_iter().collect(),
809                        sub_type: flags.into(),
810                        is_first_push: false,
811                    },
812                )
813                .await?;
814        }
815        Ok(())
816    }
817
818    fn merge_candlesticks_by_quote(&mut self, symbol: &str, push_quote: &PushQuote) {
819        let Some(market_type) = parse_market_from_symbol(symbol) else {
820            return;
821        };
822        let Some(security_data) = self.store.securities.get_mut(symbol) else {
823            return;
824        };
825
826        if push_quote.trade_session != TradeSession::Intraday {
827            return;
828        }
829
830        let half_days = self.trading_days.half_days(market_type);
831
832        for (period, candlesticks) in &mut security_data.candlesticks {
833            let Some(mtype) = merge_type(security_data.board, push_quote.trade_session, *period)
834            else {
835                continue;
836            };
837
838            let action = if mtype == MergeType::QuoteDay {
839                Some(candlesticks.merge_quote_day(market_type, security_data.board, push_quote))
840            } else if mtype == MergeType::Quote {
841                Some(candlesticks.merge_quote(
842                    market_type,
843                    half_days,
844                    security_data.board,
845                    *period,
846                    push_quote,
847                ))
848            } else {
849                None
850            };
851
852            if let Some(action) = action {
853                update_and_push_candlestick(
854                    candlesticks,
855                    push_quote.trade_session,
856                    symbol,
857                    *period,
858                    action,
859                    self.push_candlestick_mode,
860                    &mut self.push_tx,
861                );
862            }
863        }
864    }
865
866    fn merge_candlesticks_by_trades(&mut self, symbol: &str, push_trades: &PushTrades) {
867        let Some(market_type) = parse_market_from_symbol(symbol) else {
868            return;
869        };
870        let Some(security_data) = self.store.securities.get_mut(symbol) else {
871            return;
872        };
873
874        let half_days = self.trading_days.half_days(market_type);
875
876        for trade in &push_trades.trades {
877            for (period, candlesticks) in &mut security_data.candlesticks {
878                if merge_type(security_data.board, trade.trade_session, *period)
879                    != Some(MergeType::Trade)
880                {
881                    continue;
882                }
883
884                let action = candlesticks.merge_trade(
885                    market_type,
886                    half_days,
887                    security_data.board,
888                    *period,
889                    trade,
890                );
891                update_and_push_candlestick(
892                    candlesticks,
893                    trade.trade_session,
894                    symbol,
895                    *period,
896                    action,
897                    self.push_candlestick_mode,
898                    &mut self.push_tx,
899                );
900            }
901        }
902    }
903
904    fn handle_push(&mut self, command_code: u8, body: Vec<u8>) -> Result<()> {
905        match PushEvent::parse(command_code, &body) {
906            Ok((mut event, tag)) => {
907                tracing::info!(event = ?event, tag = ?tag, "push event");
908
909                if tag != Some(PushQuoteTag::Eod) {
910                    self.store.handle_push(&mut event);
911                }
912
913                if let PushEventDetail::Quote(push_quote) = &event.detail {
914                    self.merge_candlesticks_by_quote(&event.symbol, push_quote);
915
916                    if !self
917                        .subscriptions
918                        .get(&event.symbol)
919                        .map(|sub_flags| sub_flags.contains(SubFlags::QUOTE))
920                        .unwrap_or_default()
921                    {
922                        return Ok(());
923                    }
924                } else if let PushEventDetail::Trade(trades) = &event.detail {
925                    self.merge_candlesticks_by_trades(&event.symbol, trades);
926
927                    if !self
928                        .subscriptions
929                        .get(&event.symbol)
930                        .map(|sub_flags| sub_flags.contains(SubFlags::TRADE))
931                        .unwrap_or_default()
932                    {
933                        return Ok(());
934                    }
935                }
936
937                if tag == Some(PushQuoteTag::Eod) {
938                    return Ok(());
939                }
940
941                let _ = self.push_tx.send(event);
942            }
943            Err(err) => {
944                tracing::error!(error = %err, "failed to parse push message");
945            }
946        }
947        Ok(())
948    }
949
950    fn handle_get_realtime_quote(&self, symbols: Vec<String>) -> Vec<RealtimeQuote> {
951        let mut result = Vec::new();
952
953        for symbol in symbols {
954            if let Some(data) = self.store.securities.get(&symbol) {
955                result.push(RealtimeQuote {
956                    symbol,
957                    last_done: data.quote.last_done,
958                    open: data.quote.open,
959                    high: data.quote.high,
960                    low: data.quote.low,
961                    timestamp: data.quote.timestamp,
962                    volume: data.quote.volume,
963                    turnover: data.quote.turnover,
964                    trade_status: data.quote.trade_status,
965                });
966            }
967        }
968
969        result
970    }
971
972    fn handle_get_realtime_depth(&self, symbol: String) -> SecurityDepth {
973        let mut result = SecurityDepth::default();
974        if let Some(data) = self.store.securities.get(&symbol) {
975            result.asks.clone_from(&data.asks);
976            result.bids.clone_from(&data.bids);
977        }
978        result
979    }
980
981    fn handle_get_realtime_trades(&self, symbol: String, count: usize) -> Vec<Trade> {
982        let mut res = Vec::new();
983
984        if let Some(data) = self.store.securities.get(&symbol) {
985            let trades = if data.trades.len() >= count {
986                &data.trades[data.trades.len() - count..]
987            } else {
988                &data.trades
989            };
990            res = trades.to_vec();
991        }
992        res
993    }
994
995    fn handle_get_realtime_brokers(&self, symbol: String) -> SecurityBrokers {
996        let mut result = SecurityBrokers::default();
997        if let Some(data) = self.store.securities.get(&symbol) {
998            result.ask_brokers.clone_from(&data.ask_brokers);
999            result.bid_brokers.clone_from(&data.bid_brokers);
1000        }
1001        result
1002    }
1003
1004    fn handle_get_realtime_candlesticks(
1005        &self,
1006        symbol: String,
1007        period: Period,
1008        count: usize,
1009    ) -> Vec<Candlestick> {
1010        self.store
1011            .securities
1012            .get(&symbol)
1013            .map(|data| &data.candlesticks)
1014            .and_then(|periods| periods.get(&period))
1015            .map(|candlesticks| {
1016                let candlesticks = if candlesticks.candlesticks.len() >= count {
1017                    &candlesticks.candlesticks[candlesticks.candlesticks.len() - count..]
1018                } else {
1019                    &candlesticks.candlesticks
1020                };
1021                candlesticks.to_vec()
1022            })
1023            .unwrap_or_default()
1024    }
1025}
1026
/// Selects which kind of push data drives the update of a candlestick series,
/// as decided by `merge_type`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MergeType {
    /// The series is updated from individual trades (`merge_trade`).
    Trade,
    /// The day series is updated from quote pushes; chosen by `merge_type` for
    /// the `Day` period during the `Intraday` session.
    QuoteDay,
    /// The series is updated from quote pushes (`merge_quote`).
    Quote,
}
1033
1034fn merge_type(
1035    board: SecurityBoard,
1036    trade_session: TradeSession,
1037    period: Period,
1038) -> Option<MergeType> {
1039    use Period::*;
1040    use SecurityBoard::*;
1041    use TradeSession::*;
1042
1043    if !trade_session.is_intraday() && period >= Day {
1044        return None;
1045    }
1046
1047    Some(match (board, trade_session, period) {
1048        (
1049            USDJI | USNSDQ | USSector | HKHS | HKSector | CNIX | CNSector | STI | SGSector
1050            | SPXIndex | VIXIndex,
1051            _,
1052            _,
1053        ) => {
1054            if period == Day && trade_session == Intraday {
1055                MergeType::QuoteDay
1056            } else {
1057                MergeType::Quote
1058            }
1059        }
1060        (_, _, Day) if trade_session == Intraday => MergeType::QuoteDay,
1061        _ => MergeType::Trade,
1062    })
1063}
1064
1065async fn fetch_trading_days(cli: &WsClient) -> Result<TradingDays> {
1066    let mut days = TradingDays::default();
1067    let begin_day = OffsetDateTime::now_utc().date() - time::Duration::days(5);
1068    let end_day = begin_day + time::Duration::days(30);
1069
1070    for market in [Market::HK, Market::US, Market::SG, Market::CN] {
1071        let resp = cli
1072            .request::<_, MarketTradeDayResponse>(
1073                cmd_code::GET_TRADING_DAYS,
1074                None,
1075                MarketTradeDayRequest {
1076                    market: market.to_string(),
1077                    beg_day: format_date(begin_day),
1078                    end_day: format_date(end_day),
1079                },
1080            )
1081            .await?;
1082
1083        days.normal_days.insert(
1084            market,
1085            resp.trade_day
1086                .iter()
1087                .map(|value| {
1088                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1089                })
1090                .collect::<Result<HashSet<_>>>()?,
1091        );
1092
1093        days.half_days.insert(
1094            market,
1095            resp.half_trade_day
1096                .iter()
1097                .map(|value| {
1098                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1099                })
1100                .collect::<Result<HashSet<_>>>()?,
1101        );
1102    }
1103
1104    Ok(days)
1105}
1106
1107#[allow(clippy::too_many_arguments)]
1108fn update_and_push_candlestick(
1109    candlesticks: &mut Candlesticks,
1110    ts: TradeSession,
1111    symbol: &str,
1112    period: Period,
1113    action: UpdateAction<Candlestick>,
1114    push_candlestick_mode: PushCandlestickMode,
1115    tx: &mut mpsc::UnboundedSender<PushEvent>,
1116) {
1117    let mut push_candlesticks = Vec::new();
1118
1119    match action {
1120        UpdateAction::UpdateLast(candlestick) => {
1121            let tail = candlesticks.tails.get_mut(&ts).unwrap();
1122            candlesticks.candlesticks[tail.index] = candlestick;
1123            tail.candlestick = candlestick;
1124
1125            if push_candlestick_mode == PushCandlestickMode::Realtime {
1126                push_candlesticks.push((candlestick, false));
1127            }
1128        }
1129        UpdateAction::AppendNew { confirmed, new } => {
1130            let index = if let Some(tail) = candlesticks.tails.get_mut(&ts) {
1131                candlesticks.candlesticks.insert(tail.index + 1, new);
1132                tail.index += 1;
1133                tail.candlestick = new;
1134                tail.index
1135            } else {
1136                let index = candlesticks.insert_candlestick_by_time(new);
1137                candlesticks.tails.insert(
1138                    ts,
1139                    TailCandlestick {
1140                        index,
1141                        candlestick: new,
1142                    },
1143                );
1144                index
1145            };
1146
1147            for tail in candlesticks.tails.values_mut() {
1148                if tail.index > index {
1149                    tail.index += 1;
1150                }
1151            }
1152
1153            candlesticks.check_and_remove();
1154
1155            match push_candlestick_mode {
1156                PushCandlestickMode::Realtime => {
1157                    if let Some(confirmed) = confirmed {
1158                        push_candlesticks.push((confirmed, true));
1159                    }
1160                    push_candlesticks.push((new, false));
1161                }
1162                PushCandlestickMode::Confirmed => {
1163                    if let Some(confirmed) = confirmed {
1164                        push_candlesticks.push((confirmed, true));
1165                    }
1166                }
1167            }
1168        }
1169        UpdateAction::None => {}
1170    };
1171
1172    for (candlestick, is_confirmed) in push_candlesticks {
1173        if candlesticks.trade_sessions.contains(ts) {
1174            tracing::info!(
1175                symbol = symbol,
1176                period = ?period,
1177                is_confirmed = is_confirmed,
1178                candlestick = ?candlestick,
1179                trade_session = ?ts,
1180                "push candlestick"
1181            );
1182            let _ = tx.send(PushEvent {
1183                sequence: 0,
1184                symbol: symbol.to_string(),
1185                detail: PushEventDetail::Candlestick(PushCandlestick {
1186                    period,
1187                    candlestick,
1188                    is_confirmed,
1189                }),
1190            });
1191        }
1192    }
1193}
1194
1195fn parse_market_from_symbol(symbol: &str) -> Option<Market> {
1196    let market = symbol.rfind('.').map(|idx| &symbol[idx + 1..])?;
1197    Some(match market {
1198        "US" => Market::US,
1199        "HK" => Market::HK,
1200        "SG" => Market::SG,
1201        "SH" | "SZ" => Market::CN,
1202        _ => return None,
1203    })
1204}
1205
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognised suffix maps to its market; only the text after the
    /// last `.` is considered.
    #[test]
    fn test_parse_market_from_symbol() {
        assert_eq!(parse_market_from_symbol("AAPL.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("BRK.A.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("700.HK"), Some(Market::HK));
        assert_eq!(parse_market_from_symbol("D05.SG"), Some(Market::SG));
        assert_eq!(parse_market_from_symbol("600519.SH"), Some(Market::CN));
        assert_eq!(parse_market_from_symbol("000001.SZ"), Some(Market::CN));
    }

    /// Symbols without a separator, with an empty suffix, or with an unknown
    /// (or differently-cased) suffix are rejected.
    #[test]
    fn test_parse_market_from_symbol_rejects_unknown() {
        assert_eq!(parse_market_from_symbol("AAPL"), None);
        assert_eq!(parse_market_from_symbol(""), None);
        assert_eq!(parse_market_from_symbol("AAPL."), None);
        assert_eq!(parse_market_from_symbol("AAPL.XX"), None);
        assert_eq!(parse_market_from_symbol("AAPL.us"), None);
    }

    #[test]
    fn test_merge_type() {
        use Period::*;
        use SecurityBoard::*;
        use TradeSession::*;

        // Quote-driven board.
        assert_eq!(merge_type(USDJI, Intraday, Day), Some(MergeType::QuoteDay));
        assert_eq!(merge_type(USDJI, Overnight, Day), None);
        assert_eq!(
            merge_type(USDJI, Intraday, OneMinute),
            Some(MergeType::Quote)
        );
        assert_eq!(
            merge_type(USDJI, Overnight, OneMinute),
            Some(MergeType::Quote)
        );
        assert_eq!(merge_type(USDJI, Intraday, Week), Some(MergeType::Quote));
        assert_eq!(merge_type(USDJI, Overnight, Week), None);

        // Trade-driven board.
        assert_eq!(merge_type(USMain, Intraday, Day), Some(MergeType::QuoteDay));
        assert_eq!(merge_type(USMain, Overnight, Day), None);
        assert_eq!(
            merge_type(USMain, Intraday, OneMinute),
            Some(MergeType::Trade)
        );
        assert_eq!(
            merge_type(USMain, Overnight, OneMinute),
            Some(MergeType::Trade)
        );
        assert_eq!(merge_type(USMain, Intraday, Week), Some(MergeType::Trade));
        assert_eq!(merge_type(USMain, Overnight, Week), None);

        assert_eq!(merge_type(SPXIndex, Intraday, Year), Some(MergeType::Quote));
        assert_eq!(merge_type(VIXIndex, Intraday, Year), Some(MergeType::Quote));
    }
}