longport/quote/
core.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::Arc,
4};
5
6use comfy_table::Table;
7use itertools::Itertools;
8use longport_candlesticks::{TradeSessionType, UpdateAction};
9use longport_httpcli::HttpClient;
10use longport_proto::quote::{
11    self, AdjustType, MarketTradeDayRequest, MarketTradeDayResponse, MultiSecurityRequest, Period,
12    PushQuoteTag, SecurityCandlestickRequest, SecurityCandlestickResponse,
13    SecurityStaticInfoResponse, SubscribeRequest, UnsubscribeRequest,
14};
15use longport_wscli::{
16    CodecType, Platform, ProtocolVersion, RateLimit, WsClient, WsClientError, WsEvent, WsSession,
17};
18use time::{Date, OffsetDateTime};
19use tokio::{
20    sync::{mpsc, oneshot},
21    time::{Duration, Instant},
22};
23
24use crate::{
25    Config, Error, Market, Result,
26    config::PushCandlestickMode,
27    quote::{
28        Candlestick, PushCandlestick, PushEvent, PushEventDetail, PushQuote, PushTrades,
29        RealtimeQuote, SecurityBoard, SecurityBrokers, SecurityDepth, Subscription, Trade,
30        TradeSession, TradeSessions, cmd_code,
31        store::{Candlesticks, Store, TailCandlestick},
32        sub_flags::SubFlags,
33        types::QuotePackageDetail,
34        utils::{format_date, parse_date},
35    },
36};
37
/// Pause before every reconnect attempt after the quote connection is lost.
const RECONNECT_DELAY: Duration = Duration::from_millis(2_000);
39
/// Requests sent to the background [`Core`] task.
///
/// Every variant carries a `oneshot` reply channel; `Core::handle_command`
/// answers on it and ignores the send error when the requester is gone.
pub(crate) enum Command {
    /// Send a raw frame (`command_code` + encoded `body`) over the websocket
    /// and reply with the raw response bytes.
    Request {
        command_code: u8,
        body: Vec<u8>,
        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
    },
    /// Subscribe `symbols` to `sub_types` on the server and record the flags
    /// locally once the server accepts.
    Subscribe {
        symbols: Vec<String>,
        sub_types: SubFlags,
        // forwarded as `SubscribeRequest::is_first_push`
        is_first_push: bool,
        reply_tx: oneshot::Sender<Result<()>>,
    },
    /// Unsubscribe `symbols` from `sub_types`; QUOTE/TRADE stay subscribed on
    /// the server for symbols that still have tracked candlesticks.
    Unsubscribe {
        symbols: Vec<String>,
        sub_types: SubFlags,
        reply_tx: oneshot::Sender<Result<()>>,
    },
    /// Start tracking candlesticks of `symbol` for `period`; replies with the
    /// current candlesticks.
    SubscribeCandlesticks {
        symbol: String,
        period: Period,
        trade_sessions: TradeSessions,
        reply_tx: oneshot::Sender<Result<Vec<Candlestick>>>,
    },
    /// Stop tracking candlesticks of `symbol` for `period`.
    UnsubscribeCandlesticks {
        symbol: String,
        period: Period,
        reply_tx: oneshot::Sender<Result<()>>,
    },
    /// Snapshot of the current subscriptions, including candlestick periods.
    Subscriptions {
        reply_tx: oneshot::Sender<Vec<Subscription>>,
    },
    // The `GetRealtime*` variants below are answered without `.await` by the
    // `handle_get_realtime_*` methods (defined outside this view); presumably
    // they read the locally maintained `Store` — TODO confirm.
    /// Latest quotes of `symbols`.
    GetRealtimeQuote {
        symbols: Vec<String>,
        reply_tx: oneshot::Sender<Vec<RealtimeQuote>>,
    },
    /// Latest depth of `symbol`.
    GetRealtimeDepth {
        symbol: String,
        reply_tx: oneshot::Sender<SecurityDepth>,
    },
    /// Up to `count` latest trades of `symbol`.
    GetRealtimeTrade {
        symbol: String,
        count: usize,
        reply_tx: oneshot::Sender<Vec<Trade>>,
    },
    /// Latest brokers of `symbol`.
    GetRealtimeBrokers {
        symbol: String,
        reply_tx: oneshot::Sender<SecurityBrokers>,
    },
    /// Up to `count` latest candlesticks of `symbol` for `period`.
    GetRealtimeCandlesticks {
        symbol: String,
        period: Period,
        count: usize,
        reply_tx: oneshot::Sender<Vec<Candlestick>>,
    },
}
95
96#[derive(Debug, Default)]
97struct TradingDays {
98    normal_days: HashMap<Market, HashSet<Date>>,
99    half_days: HashMap<Market, HashSet<Date>>,
100}
101
102impl TradingDays {
103    #[inline]
104    fn half_days(&self, market: Market) -> Days<'_> {
105        Days(self.half_days.get(&market))
106    }
107}
108
109#[derive(Debug, Copy, Clone)]
110struct Days<'a>(Option<&'a HashSet<Date>>);
111
112impl longport_candlesticks::Days for Days<'_> {
113    #[inline]
114    fn contains(&self, date: Date) -> bool {
115        match self.0 {
116            Some(days) => days.contains(&date),
117            None => false,
118        }
119    }
120}
121
/// Quote packages held for one market, built from the `by_market_code`
/// section of the user quote profile in `Core::try_new`.
#[derive(Debug)]
pub(crate) struct MarketPackageDetail {
    /// Market code as reported by the server.
    pub(crate) market: String,
    /// Packages available in this market.
    pub(crate) packages: Vec<QuotePackageDetail>,
    /// Server warning text; when non-empty it is printed instead of the
    /// package names in the startup table.
    pub(crate) warning: String,
}
128
/// State owned by the background quote task: the websocket connection, the
/// user's subscriptions and the locally maintained market data.
pub(crate) struct Core {
    config: Arc<Config>,
    /// Per-command rate limits from the user profile; re-applied on reconnect.
    rate_limit: Vec<(u8, RateLimit)>,
    /// Commands from the quote context; once closed, the task stops.
    command_rx: mpsc::UnboundedReceiver<Command>,
    /// Push events delivered to the user.
    push_tx: mpsc::UnboundedSender<PushEvent>,
    /// Kept alive so `event_rx` never closes; cloned into every new connection.
    event_tx: mpsc::UnboundedSender<WsEvent>,
    event_rx: mpsc::UnboundedReceiver<WsEvent>,
    http_cli: HttpClient,
    ws_cli: WsClient,
    /// Current session; `None` forces a fresh OTP authentication on reconnect.
    session: Option<WsSession>,
    /// Set once `command_rx` is closed; ends the loop in `run`.
    close: bool,
    /// Subscription flags explicitly requested by the user, per symbol.
    subscriptions: HashMap<String, SubFlags>,
    trading_days: TradingDays,
    store: Store,
    member_id: i64,
    quote_level: String,
    quote_package_details: Vec<QuotePackageDetail>,
    push_candlestick_mode: PushCandlestickMode,
}
148
149impl Core {
150    pub(crate) async fn try_new(
151        config: Arc<Config>,
152        command_rx: mpsc::UnboundedReceiver<Command>,
153        push_tx: mpsc::UnboundedSender<PushEvent>,
154    ) -> Result<Self> {
155        let http_cli = config.create_http_client();
156        let otp = http_cli.get_otp().await?;
157
158        let (event_tx, event_rx) = mpsc::unbounded_channel();
159
160        tracing::info!("connecting to quote server");
161        let (url, res) = config.create_quote_ws_request().await;
162        let request = res.map_err(WsClientError::from)?;
163
164        let mut ws_cli = WsClient::open(
165            request,
166            ProtocolVersion::Version1,
167            CodecType::Protobuf,
168            Platform::OpenAPI,
169            event_tx.clone(),
170            vec![],
171        )
172        .await?;
173
174        tracing::info!(url = url, "quote server connected");
175
176        let session = ws_cli.request_auth(otp, config.create_metadata()).await?;
177
178        // fetch user profile
179        let resp = ws_cli
180            .request::<_, quote::UserQuoteProfileResponse>(
181                cmd_code::QUERY_USER_QUOTE_PROFILE,
182                None,
183                quote::UserQuoteProfileRequest {
184                    language: config.language.to_string(),
185                },
186            )
187            .await?;
188        let member_id = resp.member_id;
189        let quote_level = resp.quote_level;
190        let (quote_package_details, quote_package_details_by_market) = resp
191            .quote_level_detail
192            .map(|details| {
193                Ok::<_, Error>((
194                    details
195                        .by_package_key
196                        .into_values()
197                        .map(TryInto::try_into)
198                        .collect::<Result<Vec<_>>>()?,
199                    details
200                        .by_market_code
201                        .into_iter()
202                        .map(|(market, market_packages)| {
203                            Ok(MarketPackageDetail {
204                                market,
205                                packages: market_packages
206                                    .packages
207                                    .into_iter()
208                                    .map(TryInto::try_into)
209                                    .collect::<Result<Vec<_>>>()?,
210                                warning: market_packages.warning_msg,
211                            })
212                        })
213                        .collect::<Result<Vec<_>>>()?,
214                ))
215            })
216            .transpose()?
217            .unwrap_or_default();
218        let rate_limit: Vec<(u8, RateLimit)> = resp
219            .rate_limit
220            .iter()
221            .map(|config| {
222                (
223                    config.command as u8,
224                    RateLimit {
225                        interval: Duration::from_secs(1),
226                        initial: config.burst as usize,
227                        max: config.burst as usize,
228                        refill: config.limit as usize,
229                    },
230                )
231            })
232            .collect();
233        ws_cli.set_rate_limit(rate_limit.clone());
234
235        let current_trade_days = fetch_trading_days(&ws_cli).await?;
236        let push_candlestick_mode = config.push_candlestick_mode.unwrap_or_default();
237
238        let mut table = Table::new();
239        for market_packages in quote_package_details_by_market {
240            if market_packages.warning.is_empty() {
241                table.add_row(vec![
242                    market_packages.market,
243                    market_packages
244                        .packages
245                        .into_iter()
246                        .map(|package| package.name)
247                        .join(", "),
248                ]);
249            } else {
250                table.add_row(vec![market_packages.market, market_packages.warning]);
251            }
252        }
253
254        if config.enable_print_quote_packages {
255            println!("{table}");
256        }
257
258        tracing::info!(
259            member_id = member_id,
260            quote_level = quote_level,
261            quote_package_details = ?quote_package_details,
262            "quote context initialized",
263        );
264
265        Ok(Self {
266            config,
267            rate_limit,
268            command_rx,
269            push_tx,
270            event_tx,
271            event_rx,
272            http_cli,
273            ws_cli,
274            session: Some(session),
275            close: false,
276            subscriptions: HashMap::new(),
277            trading_days: current_trade_days,
278            store: Store::default(),
279            member_id,
280            quote_level,
281            quote_package_details,
282            push_candlestick_mode,
283        })
284    }
285
286    #[inline]
287    pub(crate) fn member_id(&self) -> i64 {
288        self.member_id
289    }
290
291    #[inline]
292    pub(crate) fn quote_level(&self) -> &str {
293        &self.quote_level
294    }
295
296    #[inline]
297    pub(crate) fn quote_package_details(&self) -> &[QuotePackageDetail] {
298        &self.quote_package_details
299    }
300
    /// Runs the quote task until the command channel is closed.
    ///
    /// Each iteration runs [`Core::main_loop`]. `Ok(())` means the command
    /// channel closed, so the task ends. On error the connection is
    /// re-established in a retry loop that waits [`RECONNECT_DELAY`] before
    /// every attempt:
    ///
    /// 1. Open a new websocket, reusing the profile's rate limits.
    /// 2. Resume the previous session if it has not expired; otherwise
    ///    authenticate with a fresh OTP.
    /// 3. Resubscribe every topic.
    ///
    /// A failure at any step logs the error and restarts the attempt at step 1.
    ///
    /// # Panics
    ///
    /// Panics if `create_quote_ws_request` returns an error; the `expect`
    /// treats that as a bug.
    ///
    /// NOTE(review): `command_rx` is not polled while reconnecting, so
    /// commands queue up, and a closed channel goes unnoticed, until a
    /// reconnect succeeds.
    pub(crate) async fn run(mut self) {
        while !self.close {
            match self.main_loop().await {
                Ok(()) => return,
                Err(err) => tracing::error!(error = %err, "quote disconnected"),
            }

            loop {
                // reconnect
                tokio::time::sleep(RECONNECT_DELAY).await;

                tracing::info!("connecting to quote server");
                let (url, res) = self.config.create_quote_ws_request().await;
                let request = res.expect("BUG: failed to create quote ws request");

                match WsClient::open(
                    request,
                    ProtocolVersion::Version1,
                    CodecType::Protobuf,
                    Platform::OpenAPI,
                    self.event_tx.clone(),
                    self.rate_limit.clone(),
                )
                .await
                {
                    Ok(ws_cli) => self.ws_cli = ws_cli,
                    Err(err) => {
                        tracing::error!(error = %err, "failed to connect quote server");
                        continue;
                    }
                }

                tracing::info!(url = url, "quote server connected");

                // request new session
                match &self.session {
                    // Try to resume the existing session; if that fails, drop
                    // it so the next attempt authenticates from scratch.
                    Some(session) if !session.is_expired() => {
                        match self
                            .ws_cli
                            .request_reconnect(&session.session_id, self.config.create_metadata())
                            .await
                        {
                            Ok(new_session) => self.session = Some(new_session),
                            Err(err) => {
                                self.session = None; // invalid session
                                tracing::error!(error = %err, "failed to request session id");
                                continue;
                            }
                        }
                    }
                    // No session, or it expired: authenticate with a new OTP.
                    _ => {
                        let otp = match self.http_cli.get_otp().await {
                            Ok(otp) => otp,
                            Err(err) => {
                                tracing::error!(error = %err, "failed to request otp");
                                continue;
                            }
                        };

                        match self
                            .ws_cli
                            .request_auth(otp, self.config.create_metadata())
                            .await
                        {
                            Ok(new_session) => self.session = Some(new_session),
                            Err(err) => {
                                tracing::error!(error = %err, "failed to request session id");
                                continue;
                            }
                        }
                    }
                }

                // handle reconnect
                match self.resubscribe().await {
                    Ok(()) => break,
                    Err(err) => {
                        tracing::error!(error = %err, "failed to subscribe topics");
                        continue;
                    }
                }
            }
        }
    }
385
386    async fn main_loop(&mut self) -> Result<()> {
387        let mut update_trading_days_interval = tokio::time::interval_at(
388            Instant::now() + Duration::from_secs(60 * 60 * 24),
389            Duration::from_secs(60 * 60 * 24),
390        );
391
392        loop {
393            tokio::select! {
394                item = self.event_rx.recv() => {
395                    match item {
396                        Some(event) => self.handle_ws_event(event).await?,
397                        None => unreachable!(),
398                    }
399                }
400                item = self.command_rx.recv() => {
401                    match item {
402                        Some(command) => self.handle_command(command).await?,
403                        None => {
404                            self.close = true;
405                            return Ok(());
406                        }
407                    }
408                }
409                _ = update_trading_days_interval.tick() => {
410                    if let Ok(days) = fetch_trading_days(&self.ws_cli).await {
411                        self.trading_days = days;
412                    }
413                }
414            }
415        }
416    }
417
418    async fn handle_command(&mut self, command: Command) -> Result<()> {
419        match command {
420            Command::Request {
421                command_code,
422                body,
423                reply_tx,
424            } => self.handle_request(command_code, body, reply_tx).await,
425            Command::Subscribe {
426                symbols,
427                sub_types,
428                is_first_push,
429                reply_tx,
430            } => {
431                let res = self
432                    .handle_subscribe(symbols, sub_types, is_first_push)
433                    .await;
434                let _ = reply_tx.send(res);
435                Ok(())
436            }
437            Command::Unsubscribe {
438                symbols,
439                sub_types,
440                reply_tx,
441            } => {
442                let _ = reply_tx.send(self.handle_unsubscribe(symbols, sub_types).await);
443                Ok(())
444            }
445            Command::SubscribeCandlesticks {
446                symbol,
447                period,
448                trade_sessions,
449                reply_tx,
450            } => {
451                let _ = reply_tx.send(
452                    self.handle_subscribe_candlesticks(symbol, period, trade_sessions)
453                        .await,
454                );
455                Ok(())
456            }
457            Command::UnsubscribeCandlesticks {
458                symbol,
459                period,
460                reply_tx,
461            } => {
462                let _ = reply_tx.send(self.handle_unsubscribe_candlesticks(symbol, period).await);
463                Ok(())
464            }
465            Command::Subscriptions { reply_tx } => {
466                let res = self.handle_subscriptions().await;
467                let _ = reply_tx.send(res);
468                Ok(())
469            }
470            Command::GetRealtimeQuote { symbols, reply_tx } => {
471                let _ = reply_tx.send(self.handle_get_realtime_quote(symbols));
472                Ok(())
473            }
474            Command::GetRealtimeDepth { symbol, reply_tx } => {
475                let _ = reply_tx.send(self.handle_get_realtime_depth(symbol));
476                Ok(())
477            }
478            Command::GetRealtimeTrade {
479                symbol,
480                count,
481                reply_tx,
482            } => {
483                let _ = reply_tx.send(self.handle_get_realtime_trades(symbol, count));
484                Ok(())
485            }
486            Command::GetRealtimeBrokers { symbol, reply_tx } => {
487                let _ = reply_tx.send(self.handle_get_realtime_brokers(symbol));
488                Ok(())
489            }
490            Command::GetRealtimeCandlesticks {
491                symbol,
492                period,
493                count,
494                reply_tx,
495            } => {
496                let _ = reply_tx.send(self.handle_get_realtime_candlesticks(symbol, period, count));
497                Ok(())
498            }
499        }
500    }
501
502    async fn handle_request(
503        &mut self,
504        command_code: u8,
505        body: Vec<u8>,
506        reply_tx: oneshot::Sender<Result<Vec<u8>>>,
507    ) -> Result<()> {
508        let res = self.ws_cli.request_raw(command_code, None, body).await;
509        let _ = reply_tx.send(res.map_err(Into::into));
510        Ok(())
511    }
512
513    async fn handle_subscribe(
514        &mut self,
515        symbols: Vec<String>,
516        sub_types: SubFlags,
517        is_first_push: bool,
518    ) -> Result<()> {
519        // send request
520        let req = SubscribeRequest {
521            symbol: symbols.clone(),
522            sub_type: sub_types.into(),
523            is_first_push,
524        };
525        self.ws_cli
526            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
527            .await?;
528
529        // update subscriptions
530        for symbol in symbols {
531            self.subscriptions
532                .entry(symbol)
533                .and_modify(|flags| *flags |= sub_types)
534                .or_insert(sub_types);
535        }
536
537        Ok(())
538    }
539
540    async fn handle_unsubscribe(
541        &mut self,
542        symbols: Vec<String>,
543        sub_types: SubFlags,
544    ) -> Result<()> {
545        tracing::info!(symbols = ?symbols, sub_types = ?sub_types, "unsubscribe");
546
547        // send requests
548        let mut st_group: HashMap<SubFlags, Vec<&str>> = HashMap::new();
549
550        for symbol in &symbols {
551            let mut st = sub_types;
552
553            if let Some(candlesticks) = self
554                .store
555                .securities
556                .get(symbol)
557                .map(|data| &data.candlesticks)
558                && !candlesticks.is_empty()
559            {
560                st.remove(SubFlags::QUOTE | SubFlags::TRADE);
561            }
562
563            if !st.is_empty() {
564                st_group.entry(st).or_default().push(symbol.as_ref());
565            }
566        }
567
568        let requests = st_group
569            .iter()
570            .map(|(st, symbols)| UnsubscribeRequest {
571                symbol: symbols.iter().map(ToString::to_string).collect(),
572                sub_type: (*st).into(),
573                unsub_all: false,
574            })
575            .collect::<Vec<_>>();
576
577        for req in requests {
578            self.ws_cli
579                .request::<_, ()>(cmd_code::UNSUBSCRIBE, None, req)
580                .await?;
581        }
582
583        // update subscriptions
584        let mut remove_symbols = Vec::new();
585        for symbol in &symbols {
586            if let Some(cur_flags) = self.subscriptions.get_mut(symbol) {
587                *cur_flags &= !sub_types;
588                if cur_flags.is_empty() {
589                    remove_symbols.push(symbol);
590                }
591            }
592        }
593
594        for symbol in remove_symbols {
595            self.subscriptions.remove(symbol);
596        }
597        Ok(())
598    }
599
600    async fn handle_subscribe_candlesticks(
601        &mut self,
602        symbol: String,
603        period: Period,
604        trade_sessions: TradeSessions,
605    ) -> Result<Vec<Candlestick>> {
606        tracing::info!(symbol = symbol, period = ?period, "subscribe candlesticks");
607
608        if let Some(candlesticks) = self
609            .store
610            .securities
611            .get_mut(&symbol)
612            .and_then(|data| data.candlesticks.get_mut(&period))
613            .filter(|candlesticks| candlesticks.trade_sessions == trade_sessions)
614        {
615            candlesticks.trade_sessions = trade_sessions;
616            tracing::info!(symbol = symbol, period = ?period, trade_sessions = ?trade_sessions, "subscribed, returns candlesticks in memory");
617            return Ok(candlesticks.candlesticks.clone());
618        }
619
620        tracing::info!(symbol = symbol, "fetch symbol board");
621
622        let security_data = self.store.securities.entry(symbol.clone()).or_default();
623        if security_data.board == SecurityBoard::Unknown {
624            // update board
625            let resp: SecurityStaticInfoResponse = self
626                .ws_cli
627                .request(
628                    cmd_code::GET_BASIC_INFO,
629                    None,
630                    MultiSecurityRequest {
631                        symbol: vec![symbol.clone()],
632                    },
633                )
634                .await?;
635            if resp.secu_static_info.is_empty() {
636                return Err(Error::InvalidSecuritySymbol {
637                    symbol: symbol.clone(),
638                });
639            }
640            security_data.board = resp.secu_static_info[0].board.parse().unwrap_or_default();
641        }
642
643        tracing::info!(symbol = symbol, board = ?security_data.board, "got the symbol board");
644
645        // pull candlesticks
646        tracing::info!(symbol = symbol, period = ?period, "pull history candlesticks");
647        let resp: SecurityCandlestickResponse = self
648            .ws_cli
649            .request(
650                cmd_code::GET_SECURITY_CANDLESTICKS,
651                None,
652                SecurityCandlestickRequest {
653                    symbol: symbol.clone(),
654                    period: period.into(),
655                    count: 1000,
656                    adjust_type: AdjustType::NoAdjust.into(),
657                    trade_session: trade_sessions as i32,
658                },
659            )
660            .await?;
661        tracing::info!(symbol = symbol, period = ?period, len = resp.candlesticks.len(), "got history candlesticks");
662
663        let mut candlesticks = vec![];
664        let mut tails = HashMap::new();
665
666        for candlestick in resp.candlesticks {
667            let candlestick: Candlestick = candlestick.try_into()?;
668            let index = candlesticks.len();
669            candlesticks.push(candlestick);
670            tails.insert(
671                candlestick.trade_session,
672                TailCandlestick { index, candlestick },
673            );
674        }
675
676        tracing::info!(symbol = symbol, period = ?period, count = candlesticks.len(), tails = ?tails, "candlesticks loaded");
677
678        security_data
679            .candlesticks
680            .entry(period)
681            .or_insert_with(|| Candlesticks {
682                trade_sessions,
683                candlesticks: candlesticks.clone(),
684                tails,
685            });
686
687        // subscribe
688        if self
689            .subscriptions
690            .get(&symbol)
691            .copied()
692            .unwrap_or_else(SubFlags::empty)
693            .contains(SubFlags::QUOTE | SubFlags::TRADE)
694        {
695            return Ok(candlesticks);
696        }
697
698        tracing::info!(symbol = symbol, period = ?period, "subscribe quote for candlesticks");
699
700        let req = SubscribeRequest {
701            symbol: vec![symbol.clone()],
702            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
703            is_first_push: true,
704        };
705        self.ws_cli
706            .request::<_, ()>(cmd_code::SUBSCRIBE, None, req)
707            .await?;
708
709        tracing::info!(symbol = symbol, period = ?period, "subscribed quote for candlesticks");
710        Ok(candlesticks)
711    }
712
713    async fn handle_unsubscribe_candlesticks(
714        &mut self,
715        symbol: String,
716        period: Period,
717    ) -> Result<()> {
718        if let Some(periods) = self
719            .store
720            .securities
721            .get_mut(&symbol)
722            .map(|data| &mut data.candlesticks)
723        {
724            periods.remove(&period);
725
726            let sub_flags = self
727                .subscriptions
728                .get(&symbol)
729                .copied()
730                .unwrap_or_else(SubFlags::empty);
731
732            if periods.is_empty() && !sub_flags.intersects(SubFlags::QUOTE | SubFlags::TRADE) {
733                tracing::info!(symbol = symbol, "unsubscribe quote for candlesticks");
734                self.ws_cli
735                    .request::<_, ()>(
736                        cmd_code::UNSUBSCRIBE,
737                        None,
738                        UnsubscribeRequest {
739                            symbol: vec![symbol],
740                            sub_type: (SubFlags::QUOTE | SubFlags::TRADE).into(),
741                            unsub_all: false,
742                        },
743                    )
744                    .await?;
745            }
746        }
747
748        Ok(())
749    }
750
751    async fn handle_subscriptions(&mut self) -> Vec<Subscription> {
752        let mut subscriptions = HashMap::new();
753
754        for (symbol, sub_flags) in &self.subscriptions {
755            if sub_flags.is_empty() {
756                continue;
757            }
758
759            subscriptions.insert(
760                symbol.clone(),
761                Subscription {
762                    symbol: symbol.clone(),
763                    sub_types: *sub_flags,
764                    candlesticks: vec![],
765                },
766            );
767        }
768
769        for (symbol, data) in &self.store.securities {
770            subscriptions
771                .entry(symbol.clone())
772                .or_insert_with(|| Subscription {
773                    symbol: symbol.clone(),
774                    sub_types: SubFlags::empty(),
775                    candlesticks: vec![],
776                })
777                .candlesticks = data.candlesticks.keys().copied().collect();
778        }
779
780        subscriptions.into_values().collect()
781    }
782
783    async fn handle_ws_event(&mut self, event: WsEvent) -> Result<()> {
784        match event {
785            WsEvent::Error(err) => Err(err.into()),
786            WsEvent::Push { command_code, body } => self.handle_push(command_code, body),
787        }
788    }
789
790    async fn resubscribe(&mut self) -> Result<()> {
791        let mut subscriptions: HashMap<SubFlags, HashSet<String>> = HashMap::new();
792
793        for (symbol, flags) in &self.subscriptions {
794            subscriptions
795                .entry(*flags)
796                .or_default()
797                .insert(symbol.clone());
798        }
799
800        for (symbol, data) in &self.store.securities {
801            if !data.candlesticks.is_empty() {
802                subscriptions
803                    .entry(SubFlags::QUOTE | SubFlags::TRADE)
804                    .or_default()
805                    .insert(symbol.clone());
806            }
807        }
808
809        tracing::info!(subscriptions = ?subscriptions, "resubscribe");
810
811        for (flags, symbols) in subscriptions {
812            self.ws_cli
813                .request::<_, ()>(
814                    cmd_code::SUBSCRIBE,
815                    None,
816                    SubscribeRequest {
817                        symbol: symbols.into_iter().collect(),
818                        sub_type: flags.into(),
819                        is_first_push: false,
820                    },
821                )
822                .await?;
823        }
824        Ok(())
825    }
826
827    fn merge_candlesticks_by_quote(&mut self, symbol: &str, push_quote: &PushQuote) {
828        let Some(market_type) = parse_market_from_symbol(symbol) else {
829            return;
830        };
831        let Some(security_data) = self.store.securities.get_mut(symbol) else {
832            return;
833        };
834
835        if push_quote.trade_session != TradeSession::Intraday {
836            return;
837        }
838
839        let half_days = self.trading_days.half_days(market_type);
840
841        for (period, candlesticks) in &mut security_data.candlesticks {
842            let Some(mtype) = merge_type(security_data.board, push_quote.trade_session, *period)
843            else {
844                continue;
845            };
846
847            let action = if mtype == MergeType::QuoteDay {
848                Some(candlesticks.merge_quote_day(market_type, security_data.board, push_quote))
849            } else if mtype == MergeType::Quote {
850                Some(candlesticks.merge_quote(
851                    market_type,
852                    half_days,
853                    security_data.board,
854                    *period,
855                    push_quote,
856                ))
857            } else {
858                None
859            };
860
861            if let Some(action) = action {
862                update_and_push_candlestick(
863                    candlesticks,
864                    push_quote.trade_session,
865                    symbol,
866                    *period,
867                    action,
868                    self.push_candlestick_mode,
869                    &mut self.push_tx,
870                );
871            }
872        }
873    }
874
875    fn merge_candlesticks_by_trades(&mut self, symbol: &str, push_trades: &PushTrades) {
876        let Some(market_type) = parse_market_from_symbol(symbol) else {
877            return;
878        };
879        let Some(security_data) = self.store.securities.get_mut(symbol) else {
880            return;
881        };
882
883        let half_days = self.trading_days.half_days(market_type);
884
885        for trade in &push_trades.trades {
886            for (period, candlesticks) in &mut security_data.candlesticks {
887                if merge_type(security_data.board, trade.trade_session, *period)
888                    != Some(MergeType::Trade)
889                {
890                    continue;
891                }
892
893                let action = candlesticks.merge_trade(
894                    market_type,
895                    half_days,
896                    security_data.board,
897                    *period,
898                    trade,
899                );
900                update_and_push_candlestick(
901                    candlesticks,
902                    trade.trade_session,
903                    symbol,
904                    *period,
905                    action,
906                    self.push_candlestick_mode,
907                    &mut self.push_tx,
908                );
909            }
910        }
911    }
912
913    fn handle_push(&mut self, command_code: u8, body: Vec<u8>) -> Result<()> {
914        match PushEvent::parse(command_code, &body) {
915            Ok((mut event, tag)) => {
916                tracing::info!(event = ?event, tag = ?tag, "push event");
917
918                if tag != Some(PushQuoteTag::Eod) {
919                    self.store.handle_push(&mut event);
920                }
921
922                if let PushEventDetail::Quote(push_quote) = &event.detail {
923                    self.merge_candlesticks_by_quote(&event.symbol, push_quote);
924
925                    if !self
926                        .subscriptions
927                        .get(&event.symbol)
928                        .map(|sub_flags| sub_flags.contains(SubFlags::QUOTE))
929                        .unwrap_or_default()
930                    {
931                        return Ok(());
932                    }
933                } else if let PushEventDetail::Trade(trades) = &event.detail {
934                    self.merge_candlesticks_by_trades(&event.symbol, trades);
935
936                    if !self
937                        .subscriptions
938                        .get(&event.symbol)
939                        .map(|sub_flags| sub_flags.contains(SubFlags::TRADE))
940                        .unwrap_or_default()
941                    {
942                        return Ok(());
943                    }
944                }
945
946                if tag == Some(PushQuoteTag::Eod) {
947                    return Ok(());
948                }
949
950                let _ = self.push_tx.send(event);
951            }
952            Err(err) => {
953                tracing::error!(error = %err, "failed to parse push message");
954            }
955        }
956        Ok(())
957    }
958
959    fn handle_get_realtime_quote(&self, symbols: Vec<String>) -> Vec<RealtimeQuote> {
960        let mut result = Vec::new();
961
962        for symbol in symbols {
963            if let Some(data) = self.store.securities.get(&symbol) {
964                result.push(RealtimeQuote {
965                    symbol,
966                    last_done: data.quote.last_done,
967                    open: data.quote.open,
968                    high: data.quote.high,
969                    low: data.quote.low,
970                    timestamp: data.quote.timestamp,
971                    volume: data.quote.volume,
972                    turnover: data.quote.turnover,
973                    trade_status: data.quote.trade_status,
974                });
975            }
976        }
977
978        result
979    }
980
981    fn handle_get_realtime_depth(&self, symbol: String) -> SecurityDepth {
982        let mut result = SecurityDepth::default();
983        if let Some(data) = self.store.securities.get(&symbol) {
984            result.asks.clone_from(&data.asks);
985            result.bids.clone_from(&data.bids);
986        }
987        result
988    }
989
990    fn handle_get_realtime_trades(&self, symbol: String, count: usize) -> Vec<Trade> {
991        let mut res = Vec::new();
992
993        if let Some(data) = self.store.securities.get(&symbol) {
994            let trades = if data.trades.len() >= count {
995                &data.trades[data.trades.len() - count..]
996            } else {
997                &data.trades
998            };
999            res = trades.to_vec();
1000        }
1001        res
1002    }
1003
1004    fn handle_get_realtime_brokers(&self, symbol: String) -> SecurityBrokers {
1005        let mut result = SecurityBrokers::default();
1006        if let Some(data) = self.store.securities.get(&symbol) {
1007            result.ask_brokers.clone_from(&data.ask_brokers);
1008            result.bid_brokers.clone_from(&data.bid_brokers);
1009        }
1010        result
1011    }
1012
1013    fn handle_get_realtime_candlesticks(
1014        &self,
1015        symbol: String,
1016        period: Period,
1017        count: usize,
1018    ) -> Vec<Candlestick> {
1019        self.store
1020            .securities
1021            .get(&symbol)
1022            .map(|data| &data.candlesticks)
1023            .and_then(|periods| periods.get(&period))
1024            .map(|candlesticks| {
1025                let candlesticks = if candlesticks.candlesticks.len() >= count {
1026                    &candlesticks.candlesticks[candlesticks.candlesticks.len() - count..]
1027                } else {
1028                    &candlesticks.candlesticks
1029                };
1030                candlesticks.to_vec()
1031            })
1032            .unwrap_or_default()
1033    }
1034}
1035
/// Source of the data used to update a cached candlestick series.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum MergeType {
    /// Merge individual trade pushes (`Candlesticks::merge_trade`).
    Trade,
    /// Merge a quote push into the daily candlestick (`Candlesticks::merge_quote_day`).
    QuoteDay,
    /// Merge a quote push into an intraday/longer series (`Candlesticks::merge_quote`).
    Quote,
}
1042
1043fn merge_type(
1044    board: SecurityBoard,
1045    trade_session: TradeSession,
1046    period: Period,
1047) -> Option<MergeType> {
1048    use Period::*;
1049    use SecurityBoard::*;
1050    use TradeSession::*;
1051
1052    if !trade_session.is_intraday() && period >= Day {
1053        return None;
1054    }
1055
1056    Some(match (board, trade_session, period) {
1057        (USDJI | USNSDQ | USSector | HKHS | HKSector | CNIX | CNSector | STI | SGSector, _, _) => {
1058            if period == Day && trade_session == Intraday {
1059                MergeType::QuoteDay
1060            } else {
1061                MergeType::Quote
1062            }
1063        }
1064        (_, _, Day) if trade_session == Intraday => MergeType::QuoteDay,
1065        _ => MergeType::Trade,
1066    })
1067}
1068
1069async fn fetch_trading_days(cli: &WsClient) -> Result<TradingDays> {
1070    let mut days = TradingDays::default();
1071    let begin_day = OffsetDateTime::now_utc().date() - time::Duration::days(5);
1072    let end_day = begin_day + time::Duration::days(30);
1073
1074    for market in [Market::HK, Market::US, Market::SG, Market::CN] {
1075        let resp = cli
1076            .request::<_, MarketTradeDayResponse>(
1077                cmd_code::GET_TRADING_DAYS,
1078                None,
1079                MarketTradeDayRequest {
1080                    market: market.to_string(),
1081                    beg_day: format_date(begin_day),
1082                    end_day: format_date(end_day),
1083                },
1084            )
1085            .await?;
1086
1087        days.normal_days.insert(
1088            market,
1089            resp.trade_day
1090                .iter()
1091                .map(|value| {
1092                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1093                })
1094                .collect::<Result<HashSet<_>>>()?,
1095        );
1096
1097        days.half_days.insert(
1098            market,
1099            resp.half_trade_day
1100                .iter()
1101                .map(|value| {
1102                    parse_date(value).map_err(|err| Error::parse_field_error("half_trade_day", err))
1103                })
1104                .collect::<Result<HashSet<_>>>()?,
1105        );
1106    }
1107
1108    Ok(days)
1109}
1110
1111#[allow(clippy::too_many_arguments)]
1112fn update_and_push_candlestick(
1113    candlesticks: &mut Candlesticks,
1114    ts: TradeSession,
1115    symbol: &str,
1116    period: Period,
1117    action: UpdateAction<Candlestick>,
1118    push_candlestick_mode: PushCandlestickMode,
1119    tx: &mut mpsc::UnboundedSender<PushEvent>,
1120) {
1121    let mut push_candlesticks = Vec::new();
1122
1123    match action {
1124        UpdateAction::UpdateLast(candlestick) => {
1125            let tail = candlesticks.tails.get_mut(&ts).unwrap();
1126            candlesticks.candlesticks[tail.index] = candlestick;
1127            tail.candlestick = candlestick;
1128
1129            if push_candlestick_mode == PushCandlestickMode::Realtime {
1130                push_candlesticks.push((candlestick, false));
1131            }
1132        }
1133        UpdateAction::AppendNew { confirmed, new } => {
1134            let index = if let Some(tail) = candlesticks.tails.get_mut(&ts) {
1135                candlesticks.candlesticks.insert(tail.index + 1, new);
1136                tail.index += 1;
1137                tail.candlestick = new;
1138                tail.index
1139            } else {
1140                let index = candlesticks.insert_candlestick_by_time(new);
1141                candlesticks.tails.insert(
1142                    ts,
1143                    TailCandlestick {
1144                        index,
1145                        candlestick: new,
1146                    },
1147                );
1148                index
1149            };
1150
1151            for tail in candlesticks.tails.values_mut() {
1152                if tail.index > index {
1153                    tail.index += 1;
1154                }
1155            }
1156
1157            candlesticks.check_and_remove();
1158
1159            match push_candlestick_mode {
1160                PushCandlestickMode::Realtime => {
1161                    if let Some(confirmed) = confirmed {
1162                        push_candlesticks.push((confirmed, true));
1163                    }
1164                    push_candlesticks.push((new, false));
1165                }
1166                PushCandlestickMode::Confirmed => {
1167                    if let Some(confirmed) = confirmed {
1168                        push_candlesticks.push((confirmed, true));
1169                    }
1170                }
1171            }
1172        }
1173        UpdateAction::None => {}
1174    };
1175
1176    for (candlestick, is_confirmed) in push_candlesticks {
1177        if candlesticks.trade_sessions.contains(ts) {
1178            tracing::info!(
1179                symbol = symbol,
1180                period = ?period,
1181                is_confirmed = is_confirmed,
1182                candlestick = ?candlestick,
1183                trade_session = ?ts,
1184                "push candlestick"
1185            );
1186            let _ = tx.send(PushEvent {
1187                sequence: 0,
1188                symbol: symbol.to_string(),
1189                detail: PushEventDetail::Candlestick(PushCandlestick {
1190                    period,
1191                    candlestick,
1192                    is_confirmed,
1193                }),
1194            });
1195        }
1196    }
1197}
1198
1199fn parse_market_from_symbol(symbol: &str) -> Option<Market> {
1200    let market = symbol.rfind('.').map(|idx| &symbol[idx + 1..])?;
1201    Some(match market {
1202        "US" => Market::US,
1203        "HK" => Market::HK,
1204        "SG" => Market::SG,
1205        "SH" | "SZ" => Market::CN,
1206        _ => return None,
1207    })
1208}
1209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_market_from_symbol() {
        assert_eq!(parse_market_from_symbol("AAPL.US"), Some(Market::US));
        // Only the text after the last '.' is the market suffix.
        assert_eq!(parse_market_from_symbol("BRK.A.US"), Some(Market::US));
        assert_eq!(parse_market_from_symbol("700.HK"), Some(Market::HK));
        assert_eq!(parse_market_from_symbol("D05.SG"), Some(Market::SG));
        assert_eq!(parse_market_from_symbol("600000.SH"), Some(Market::CN));
        assert_eq!(parse_market_from_symbol("000001.SZ"), Some(Market::CN));
    }

    #[test]
    fn test_parse_market_from_symbol_rejects_unknown() {
        assert_eq!(parse_market_from_symbol("AAPL"), None);
        assert_eq!(parse_market_from_symbol("AAPL."), None);
        assert_eq!(parse_market_from_symbol("AAPL.JP"), None);
        assert_eq!(parse_market_from_symbol(""), None);
    }

    #[test]
    fn test_merge_type() {
        use Period::*;
        use SecurityBoard::*;
        use TradeSession::*;

        assert_eq!(merge_type(USDJI, Intraday, Day), Some(MergeType::QuoteDay));
        assert_eq!(merge_type(USDJI, Overnight, Day), None);
        assert_eq!(
            merge_type(USDJI, Intraday, OneMinute),
            Some(MergeType::Quote)
        );
        assert_eq!(
            merge_type(USDJI, Overnight, OneMinute),
            Some(MergeType::Quote)
        );
        assert_eq!(merge_type(USDJI, Intraday, Week), Some(MergeType::Quote));
        assert_eq!(merge_type(USDJI, Overnight, Week), None);

        assert_eq!(merge_type(USMain, Intraday, Day), Some(MergeType::QuoteDay));
        assert_eq!(merge_type(USMain, Overnight, Day), None);
        assert_eq!(
            merge_type(USMain, Intraday, OneMinute),
            Some(MergeType::Trade)
        );
        assert_eq!(
            merge_type(USMain, Overnight, OneMinute),
            Some(MergeType::Trade)
        );
        assert_eq!(merge_type(USMain, Intraday, Week), Some(MergeType::Trade));
        assert_eq!(merge_type(USMain, Overnight, Week), None);
    }

    #[test]
    fn test_merge_type_other_index_boards() {
        use Period::*;
        use SecurityBoard::*;
        use TradeSession::*;

        // Non-US index/sector boards follow the same quote-driven rules.
        for board in [HKHS, CNIX, STI, SGSector] {
            assert_eq!(merge_type(board, Intraday, Day), Some(MergeType::QuoteDay));
            assert_eq!(
                merge_type(board, Intraday, OneMinute),
                Some(MergeType::Quote)
            );
            assert_eq!(merge_type(board, Intraday, Week), Some(MergeType::Quote));
            assert_eq!(merge_type(board, Overnight, Day), None);
        }
    }
}