1use std::{
2 collections::{HashSet, VecDeque},
3 sync::Arc,
4 time::Duration,
5};
6
7use longport_httpcli::HttpClient;
8use longport_proto::trade::{Sub, SubResponse, Unsub, UnsubResponse};
9use longport_wscli::{
10 CodecType, Platform, ProtocolVersion, WsClient, WsClientError, WsEvent, WsSession,
11};
12use tokio::{
13 sync::{mpsc, oneshot},
14 time::Instant,
15};
16
17use crate::{
18 Config, Result,
19 trade::{PushEvent, PushOrderChanged, TopicType, cmd_code},
20};
21
/// Pause inserted before every reconnection attempt after the trade
/// connection drops (or after a failed attempt), i.e. two seconds.
const RECONNECT_DELAY: Duration = Duration::from_millis(2_000);
23
/// Requests sent to the trade [`Core`] task over its command channel.
pub(crate) enum Command {
    /// Subscribe to the given topics; the outcome is sent back on `reply_tx`.
    Subscribe {
        topics: Vec<TopicType>,
        reply_tx: oneshot::Sender<Result<()>>,
    },
    /// Unsubscribe from the given topics; the outcome is sent back on `reply_tx`.
    Unsubscribe {
        topics: Vec<TopicType>,
        reply_tx: oneshot::Sender<Result<()>>,
    },
    /// Notifies the core that an order with `order_id` was submitted, so any
    /// buffered order-changed events for that id are forwarded immediately.
    SubmittedOrder {
        order_id: String,
    },
}
37
/// State owned by the background task that maintains the trade websocket
/// connection, processes [`Command`]s and forwards [`PushEvent`]s.
pub(crate) struct Core {
    config: Arc<Config>,
    // Commands from the public API; when every sender is dropped the core stops.
    command_rx: mpsc::UnboundedReceiver<Command>,
    // Parsed push events are forwarded to the consumer through this channel.
    push_tx: mpsc::UnboundedSender<PushEvent>,
    // Kept so each (re)opened websocket client can be handed a clone; holding it
    // also keeps `event_rx` from ever observing a closed channel.
    event_tx: mpsc::UnboundedSender<WsEvent>,
    event_rx: mpsc::UnboundedReceiver<WsEvent>,
    // Used to obtain OTPs for (re)authentication.
    http_cli: HttpClient,
    ws_cli: WsClient,
    // Current authenticated session; `None` forces a fresh OTP authentication on reconnect.
    session: Option<WsSession>,
    // Set once the command channel closes; stops the reconnect loop.
    close: bool,
    // Topics currently subscribed, as last reported by the server; replayed after reconnecting.
    subscriptions: HashSet<String>,
    // Buffered order-changed events with the time they were queued; flushed after
    // one second or when a matching `Command::SubmittedOrder` arrives.
    unknown_orders: VecDeque<(Instant, PushOrderChanged)>,
}
51
52impl Core {
53 pub(crate) async fn try_new(
54 config: Arc<Config>,
55 command_rx: mpsc::UnboundedReceiver<Command>,
56 push_tx: mpsc::UnboundedSender<PushEvent>,
57 ) -> Result<Self> {
58 let http_cli = config.create_http_client();
59 let otp = http_cli.get_otp_v2().await?;
60
61 let (event_tx, event_rx) = mpsc::unbounded_channel();
62
63 tracing::info!("connecting to trade server");
64 let (url, res) = config.create_trade_ws_request().await;
65 let request = res.map_err(WsClientError::from)?;
66 let ws_cli = WsClient::open(
67 request,
68 ProtocolVersion::Version1,
69 CodecType::Protobuf,
70 Platform::OpenAPI,
71 event_tx.clone(),
72 vec![],
73 )
74 .await?;
75
76 tracing::info!(url = url, "trade server connected");
77
78 let session = ws_cli.request_auth(otp, Default::default()).await?;
79
80 Ok(Self {
81 config,
82 command_rx,
83 push_tx,
84 event_tx,
85 event_rx,
86 http_cli,
87 ws_cli,
88 session: Some(session),
89 close: false,
90 subscriptions: HashSet::new(),
91 unknown_orders: VecDeque::new(),
92 })
93 }
94
95 pub(crate) async fn run(mut self) {
96 while !self.close {
97 match self.main_loop().await {
98 Ok(()) => return,
99 Err(err) => tracing::error!(error = %err, "trade disconnected"),
100 }
101
102 loop {
103 tokio::time::sleep(RECONNECT_DELAY).await;
105
106 tracing::info!("connecting to trade server");
107 let (url, res) = self.config.create_trade_ws_request().await;
108 let request = res.expect("BUG: failed to create trade ws request");
109
110 match WsClient::open(
111 request,
112 ProtocolVersion::Version1,
113 CodecType::Protobuf,
114 Platform::OpenAPI,
115 self.event_tx.clone(),
116 vec![],
117 )
118 .await
119 {
120 Ok(ws_cli) => self.ws_cli = ws_cli,
121 Err(err) => {
122 tracing::error!(error = %err, "failed to connect trade server");
123 continue;
124 }
125 }
126
127 tracing::info!(url = url, "trade server connected");
128
129 match &self.session {
131 Some(session) if !session.is_expired() => {
132 match self
133 .ws_cli
134 .request_reconnect(&session.session_id, Default::default())
135 .await
136 {
137 Ok(new_session) => self.session = Some(new_session),
138 Err(err) => {
139 self.session = None; tracing::error!(error = %err, "failed to request session id");
141 continue;
142 }
143 }
144 }
145 _ => {
146 let otp = match self.http_cli.get_otp_v2().await {
147 Ok(otp) => otp,
148 Err(err) => {
149 tracing::error!(error = %err, "failed to request otp");
150 continue;
151 }
152 };
153
154 match self.ws_cli.request_auth(otp, Default::default()).await {
155 Ok(new_session) => self.session = Some(new_session),
156 Err(err) => {
157 tracing::error!(error = %err, "failed to request session id");
158 continue;
159 }
160 }
161 }
162 }
163
164 match self.resubscribe().await {
166 Ok(()) => break,
167 Err(err) => {
168 tracing::error!(error = %err, "failed to subscribe topics");
169 continue;
170 }
171 }
172 }
173 }
174 }
175
176 async fn main_loop(&mut self) -> Result<()> {
177 let mut tick = tokio::time::interval(Duration::from_millis(500));
178
179 loop {
180 tokio::select! {
181 item = self.event_rx.recv() => {
182 match item {
183 Some(event) => self.handle_ws_event(event).await?,
184 None => unreachable!(),
185 }
186 }
187 item = self.command_rx.recv() => {
188 match item {
189 Some(command) => self.handle_command(command).await?,
190 None => {
191 self.close = true;
192 return Ok(());
193 }
194 }
195 }
196 now = tick.tick() => self.handle_tick(now),
197 }
198 }
199 }
200
201 async fn handle_ws_event(&mut self, event: WsEvent) -> Result<()> {
202 match event {
203 WsEvent::Error(err) => Err(err.into()),
204 WsEvent::Push { command_code, body } => self.handle_push(command_code, body).await,
205 }
206 }
207
208 async fn handle_push(&mut self, command_code: u8, body: Vec<u8>) -> Result<()> {
209 match PushEvent::parse(command_code, &body) {
210 Ok(Some(event)) => {
211 tracing::info!(event = ?event, "push event");
212 let _ = self.push_tx.send(event);
213 }
214 Ok(None) => {}
215 Err(err) => {
216 tracing::error!(error = %err, "failed to parse push message")
217 }
218 }
219 Ok(())
220 }
221
222 fn handle_tick(&mut self, now: Instant) {
223 while let Some((t, _)) = self.unknown_orders.front() {
224 if now - *t > Duration::from_secs(1) {
225 let (_, order_changed) = self.unknown_orders.pop_front().unwrap();
226 _ = self.push_tx.send(PushEvent::OrderChanged(order_changed));
227 } else {
228 break;
229 }
230 }
231 }
232
233 async fn handle_command(&mut self, command: Command) -> Result<()> {
234 match command {
235 Command::Subscribe { topics, reply_tx } => {
236 let res = self.handle_subscribe(topics).await;
237 let _ = reply_tx.send(res);
238 Ok(())
239 }
240 Command::Unsubscribe { topics, reply_tx } => {
241 let res = self.handle_unsubscribe(topics).await;
242 let _ = reply_tx.send(res);
243 Ok(())
244 }
245 Command::SubmittedOrder { order_id } => {
246 while let Some((idx, _)) = self
247 .unknown_orders
248 .iter()
249 .enumerate()
250 .find(|(_, (_, order_changed))| order_changed.order_id == order_id)
251 {
252 let Some((_, order_changed)) = self.unknown_orders.remove(idx) else {
253 unreachable!();
254 };
255 let _ = self.push_tx.send(PushEvent::OrderChanged(order_changed));
256 }
257 Ok(())
258 }
259 }
260 }
261
262 async fn handle_subscribe(&mut self, topics: Vec<TopicType>) -> Result<()> {
263 let req = Sub {
264 topics: topics.iter().map(ToString::to_string).collect(),
265 };
266 tracing::info!(topics = ?req.topics, "subscribing topics");
267 let resp: SubResponse = self.ws_cli.request(cmd_code::SUBSCRIBE, None, req).await?;
268 self.subscriptions = resp.current.into_iter().collect();
269 Ok(())
270 }
271
272 async fn handle_unsubscribe(&mut self, topics: Vec<TopicType>) -> Result<()> {
273 let req = Unsub {
274 topics: topics.iter().map(ToString::to_string).collect(),
275 };
276 tracing::info!(topics = ?req.topics, "unsubscribing topics");
277 let resp: UnsubResponse = self
278 .ws_cli
279 .request(cmd_code::UNSUBSCRIBE, None, req)
280 .await?;
281 self.subscriptions = resp.current.into_iter().collect();
282
283 Ok(())
284 }
285
286 async fn resubscribe(&mut self) -> Result<()> {
287 let req = Sub {
288 topics: self.subscriptions.iter().cloned().collect(),
289 };
290 let resp: SubResponse = self.ws_cli.request(cmd_code::SUBSCRIBE, None, req).await?;
291 self.subscriptions = resp.current.into_iter().collect();
292 Ok(())
293 }
294}