use std::{
    future::Future,
    io, mem,
    num::NonZeroUsize,
    pin::Pin,
    rc::Rc,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};

use actix_rt::{
    spawn,
    time::{sleep, Instant, Sleep},
    Arbiter, ArbiterHandle, System,
};
use futures_core::{future::LocalBoxFuture, ready};
use tokio::sync::{
    mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
    oneshot,
};
use tracing::{error, info, trace};

use crate::{
    service::{BoxedServerService, InternalServiceFactory},
    socket::MioStream,
    waker_queue::{WakerInterest, WakerQueue},
};

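/// Stop command sent to a worker.
///
/// The oneshot sender is signalled with `true` if the worker shuts down without dropping
/// connections and `false` otherwise.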
pub(crate) struct Stop {
    graceful: bool,
    tx: oneshot::Sender<bool>,
}

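/// An accepted connection, tagged with the token of the service that should handle it.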
#[derive(Debug)]
pub(crate) struct Conn {
    pub io: MioStream,
    pub token: usize,
}

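/// Creates the pair of handles used by the accept loop and the server to communicate with a worker.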
fn handle_pair(
    idx: usize,
    conn_tx: UnboundedSender<Conn>,
    stop_tx: UnboundedSender<Stop>,
    counter: Counter,
) -> (WorkerHandleAccept, WorkerHandleServer) {
    let accept = WorkerHandleAccept {
        idx,
        conn_tx,
        counter,
    };

    let server = WorkerHandleServer { idx, stop_tx };

    (accept, server)
}

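/// Connection counter shared between the accept loop and a worker.
///
/// The accept loop increments it for every connection sent to the worker and the worker
/// decrements it when a connection finishes. The stored value starts at 1, so `total()` subtracts
/// one to report the number of in-flight connections.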
#[derive(Clone)]
pub(crate) struct Counter {
    counter: Arc<AtomicUsize>,
    limit: usize,
}

impl Counter {
    pub(crate) fn new(limit: usize) -> Self {
        Self {
            counter: Arc::new(AtomicUsize::new(1)),
            limit,
        }
    }

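    /// Increments the counter by one, returning `true` while the worker is still below its
    /// connection limit and `false` once this increment reaches it.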
    #[inline(always)]
    pub(crate) fn inc(&self) -> bool {
        self.counter.fetch_add(1, Ordering::Relaxed) != self.limit
    }

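    /// Decrements the counter by one, returning `true` when the pre-decrement value equals the
    /// limit; the caller uses this to decide when to wake the accept loop again.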
    #[inline(always)]
    pub(crate) fn dec(&self) -> bool {
        self.counter.fetch_sub(1, Ordering::Relaxed) == self.limit
    }

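    /// Returns the number of connections currently being handled (the stored value is offset by one).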
    pub(crate) fn total(&self) -> usize {
        self.counter.load(Ordering::SeqCst) - 1
    }
}

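/// Worker-side connection counter, paired with the waker queue used to notify the accept loop
/// when capacity is released.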
pub(crate) struct WorkerCounter {
    idx: usize,
    inner: Rc<(WakerQueue, Counter)>,
}

impl Clone for WorkerCounter {
    fn clone(&self) -> Self {
        Self {
            idx: self.idx,
            inner: self.inner.clone(),
        }
    }
}

impl WorkerCounter {
    pub(crate) fn new(idx: usize, waker_queue: WakerQueue, counter: Counter) -> Self {
        Self {
            idx,
            inner: Rc::new((waker_queue, counter)),
        }
    }

    #[inline(always)]
    pub(crate) fn guard(&self) -> WorkerCounterGuard {
        WorkerCounterGuard(self.clone())
    }

    fn total(&self) -> usize {
        self.inner.1.total()
    }
}

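/// Guard held for the lifetime of a connection; on drop it decrements the counter and, when the
/// count crosses back over the limit, wakes the accept loop with `WakerInterest::WorkerAvailable`.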
pub(crate) struct WorkerCounterGuard(WorkerCounter);

impl Drop for WorkerCounterGuard {
    fn drop(&mut self) {
        let (waker_queue, counter) = &*self.0.inner;
        if counter.dec() {
            waker_queue.wake(WakerInterest::WorkerAvailable(self.0.idx));
        }
    }
}

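/// Handle held by the accept loop: forwards accepted connections to the worker and tracks its
/// connection count.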
pub(crate) struct WorkerHandleAccept {
    idx: usize,
    conn_tx: UnboundedSender<Conn>,
    counter: Counter,
}

impl WorkerHandleAccept {
    #[inline(always)]
    pub(crate) fn idx(&self) -> usize {
        self.idx
    }

    #[inline(always)]
    pub(crate) fn send(&self, conn: Conn) -> Result<(), Conn> {
        self.conn_tx.send(conn).map_err(|msg| msg.0)
    }

    #[inline(always)]
    pub(crate) fn inc_counter(&self) -> bool {
        self.counter.inc()
    }
}

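/// Handle held by the server to ask a worker to stop, either gracefully or forcefully.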
#[derive(Debug)]
pub(crate) struct WorkerHandleServer {
    pub(crate) idx: usize,
    stop_tx: UnboundedSender<Stop>,
}

impl WorkerHandleServer {
    pub(crate) fn stop(&self, graceful: bool) -> oneshot::Receiver<bool> {
        let (tx, rx) = oneshot::channel();
        let _ = self.stop_tx.send(Stop { graceful, tx });
        rx
    }
}

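/// Server worker.
///
/// Receives connections from the accept loop over an unbounded channel, dispatches them to its
/// services, and reacts to stop commands by shutting down gracefully or immediately.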
pub(crate) struct ServerWorker {
    conn_rx: UnboundedReceiver<Conn>,
    stop_rx: UnboundedReceiver<Stop>,
    counter: WorkerCounter,
    services: Box<[WorkerService]>,
    factories: Box<[Box<dyn InternalServiceFactory>]>,
    state: WorkerState,
    shutdown_timeout: Duration,
}

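/// A service created by one of the worker's factories, together with its readiness status.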
struct WorkerService {
    factory_idx: usize,
    status: WorkerServiceStatus,
    service: BoxedServerService,
}

impl WorkerService {
    fn created(&mut self, service: BoxedServerService) {
        self.service = service;
        self.status = WorkerServiceStatus::Unavailable;
    }
}

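/// Readiness and lifecycle state of a single worker service.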
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerServiceStatus {
    Available,
    Unavailable,
    Failed,
    Restarting,
    Stopping,
    Stopped,
}

impl Default for WorkerServiceStatus {
    fn default() -> Self {
        Self::Unavailable
    }
}

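/// Per-worker configuration: shutdown timeout, blocking thread pool size, and the maximum number
/// of concurrent connections.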
#[derive(Debug, Clone, Copy)]
pub(crate) struct ServerWorkerConfig {
    shutdown_timeout: Duration,
    max_blocking_threads: usize,
    max_concurrent_connections: usize,
}

impl Default for ServerWorkerConfig {
    fn default() -> Self {
        let parallelism = std::thread::available_parallelism().map_or(2, NonZeroUsize::get);

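        // 512 is the default max blocking thread count of a Tokio runtime; split that budget
        // across the available cores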
        let max_blocking_threads = std::cmp::max(512 / parallelism, 1);

        Self {
            shutdown_timeout: Duration::from_secs(30),
            max_blocking_threads,
            max_concurrent_connections: 25600,
        }
    }
}

impl ServerWorkerConfig {
    pub(crate) fn max_blocking_threads(&mut self, num: usize) {
        self.max_blocking_threads = num;
    }

    pub(crate) fn max_concurrent_connections(&mut self, num: usize) {
        self.max_concurrent_connections = num;
    }

    pub(crate) fn shutdown_timeout(&mut self, dur: Duration) {
        self.shutdown_timeout = dur;
    }
}

impl ServerWorker {
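    /// Starts a worker on its own thread, creating its services there, and returns the handles
    /// used by the accept loop and the server to communicate with it.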
    pub(crate) fn start(
        idx: usize,
        factories: Vec<Box<dyn InternalServiceFactory>>,
        waker_queue: WakerQueue,
        config: ServerWorkerConfig,
    ) -> io::Result<(WorkerHandleAccept, WorkerHandleServer)> {
        trace!("starting server worker {}", idx);

        let (tx1, conn_rx) = unbounded_channel();
        let (tx2, stop_rx) = unbounded_channel();

        let counter = Counter::new(config.max_concurrent_connections);
        let pair = handle_pair(idx, tx1, tx2, counter.clone());

        // use the Actix System, if one is set up
        let actix_system = System::try_current();

        // otherwise fall back to an ambient Tokio runtime handle, if any
        let tokio_handle = tokio::runtime::Handle::try_current().ok();

        // channel used by the worker thread to report the outcome of service creation
        let (factory_tx, factory_rx) = std::sync::mpsc::sync_channel::<io::Result<()>>(1);

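        // Every worker runs on its own thread. Which runtime drives that thread depends on the
        // context `start` was called from: an Actix System or an external Tokio runtime.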
        match (actix_system, tokio_handle) {
            (None, None) => {
                panic!("No runtime detected. Start a Tokio (or Actix) runtime.");
            }

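            // no Actix System, but an external Tokio runtime is available: run the worker on a
            // dedicated thread, using a `LocalSet` so its non-Send futures can be driven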
            (None, Some(rt_handle)) => {
                std::thread::Builder::new()
                    .name(format!("actix-server worker {}", idx))
                    .spawn(move || {
                        let (worker_stopped_tx, worker_stopped_rx) = oneshot::channel();

                        // local set used for both service creation and the worker future
                        let ls = tokio::task::LocalSet::new();

                        // create services on this thread using the external runtime handle
                        let services = rt_handle.block_on(ls.run_until(async {
                            let mut services = Vec::new();

                            for (idx, factory) in factories.iter().enumerate() {
                                match factory.create().await {
                                    Ok((token, svc)) => services.push((idx, token, svc)),

                                    Err(err) => {
                                        error!("can not start worker: {err:?}");
                                        return Err(io::Error::other(format!(
                                            "can not start server service {idx}",
                                        )));
                                    }
                                }
                            }

                            Ok(services)
                        }));

                        // report the outcome of service creation back to `start`
                        let services = match services {
                            Ok(services) => {
                                factory_tx.send(Ok(())).unwrap();
                                services
                            }
                            Err(err) => {
                                factory_tx.send(Err(err)).unwrap();
                                return;
                            }
                        };

                        let worker_services = wrap_worker_services(services);

                        // spawn the worker task and wait for it to signal completion
                        let worker_fut = async move {
                            spawn(async move {
                                ServerWorker {
                                    conn_rx,
                                    stop_rx,
                                    services: worker_services.into_boxed_slice(),
                                    counter: WorkerCounter::new(idx, waker_queue, counter),
                                    factories: factories.into_boxed_slice(),
                                    state: WorkerState::default(),
                                    shutdown_timeout: config.shutdown_timeout,
                                }
                                .await;

                                // signal the outer future that the worker has finished
                                worker_stopped_tx.send(()).unwrap();
                            });

                            worker_stopped_rx.await.unwrap();
                        };

                        #[cfg(all(target_os = "linux", feature = "io-uring"))]
                        {
                            // io-uring builds drive the worker with tokio_uring; the blocking
                            // thread setting is not applied in this path
                            let _ = config.max_blocking_threads;
                            tokio_uring::start(worker_fut);
                        }

                        #[cfg(not(all(target_os = "linux", feature = "io-uring")))]
                        {
                            let rt = tokio::runtime::Builder::new_current_thread()
                                .enable_all()
                                .max_blocking_threads(config.max_blocking_threads)
                                .build()
                                .unwrap();

                            rt.block_on(ls.run_until(worker_fut));
                        }
                    })
                    .expect("cannot spawn server worker thread");
            }

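            // an Actix System is running: create a dedicated Arbiter (thread) for the worker and
            // build its services there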
            (Some(_sys), _) => {
                #[cfg(all(target_os = "linux", feature = "io-uring"))]
                let arbiter = {
                    // the default Arbiter is used for io-uring builds; the blocking thread
                    // setting is not applied in this path
                    let _ = config.max_blocking_threads;
                    Arbiter::new()
                };

                #[cfg(not(all(target_os = "linux", feature = "io-uring")))]
                let arbiter = {
                    Arbiter::with_tokio_rt(move || {
                        tokio::runtime::Builder::new_current_thread()
                            .enable_all()
                            .max_blocking_threads(config.max_blocking_threads)
                            .build()
                            .unwrap()
                    })
                };

                arbiter.spawn(async move {
                    // create the services, then spawn the worker future on this Arbiter
                    spawn(async move {
                        let mut services = Vec::new();

                        for (idx, factory) in factories.iter().enumerate() {
                            match factory.create().await {
                                Ok((token, svc)) => services.push((idx, token, svc)),

                                Err(err) => {
                                    error!("can not start worker: {err:?}");
                                    Arbiter::current().stop();
                                    factory_tx
                                        .send(Err(io::Error::other(format!(
                                            "can not start server service {idx}",
                                        ))))
                                        .unwrap();
                                    return;
                                }
                            }
                        }

                        factory_tx.send(Ok(())).unwrap();

                        let worker_services = wrap_worker_services(services);

                        spawn(ServerWorker {
                            conn_rx,
                            stop_rx,
                            services: worker_services.into_boxed_slice(),
                            counter: WorkerCounter::new(idx, waker_queue, counter),
                            factories: factories.into_boxed_slice(),
                            state: Default::default(),
                            shutdown_timeout: config.shutdown_timeout,
                        });
                    });
                });
            }
        };

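        // wait for the worker thread to report whether all services were created successfully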
        factory_rx.recv().unwrap()?;

        Ok(pair)
    }

    fn restart_service(&mut self, idx: usize, factory_id: usize) {
        let factory = &self.factories[factory_id];
        trace!("service {:?} failed, restarting", factory.name(idx));
        self.services[idx].status = WorkerServiceStatus::Restarting;
        self.state = WorkerState::Restarting(Restart {
            factory_id,
            token: idx,
            fut: factory.create(),
        });
    }

    /// Marks all currently available services as `Stopped` (forced) or `Stopping` (graceful).
    fn shutdown(&mut self, force: bool) {
        self.services
            .iter_mut()
            .filter(|srv| srv.status == WorkerServiceStatus::Available)
            .for_each(|srv| {
                srv.status = if force {
                    WorkerServiceStatus::Stopped
                } else {
                    WorkerServiceStatus::Stopping
                };
            });
    }

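    /// Polls readiness of every service.
    ///
    /// Returns `Ok(true)` when all services are ready, `Ok(false)` when at least one is pending,
    /// and `Err((token, factory_idx))` when a service failed and must be restarted.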
    fn check_readiness(&mut self, cx: &mut Context<'_>) -> Result<bool, (usize, usize)> {
        let mut ready = true;
        for (idx, srv) in self.services.iter_mut().enumerate() {
            if srv.status == WorkerServiceStatus::Available
                || srv.status == WorkerServiceStatus::Unavailable
            {
                match srv.service.poll_ready(cx) {
                    Poll::Ready(Ok(_)) => {
                        if srv.status == WorkerServiceStatus::Unavailable {
                            trace!(
                                "service {:?} is available",
                                self.factories[srv.factory_idx].name(idx)
                            );
                            srv.status = WorkerServiceStatus::Available;
                        }
                    }
                    Poll::Pending => {
                        ready = false;

                        if srv.status == WorkerServiceStatus::Available {
                            trace!(
                                "service {:?} is unavailable",
                                self.factories[srv.factory_idx].name(idx)
                            );
                            srv.status = WorkerServiceStatus::Unavailable;
                        }
                    }
                    Poll::Ready(Err(_)) => {
                        error!(
                            "service {:?} readiness check returned error, restarting",
                            self.factories[srv.factory_idx].name(idx)
                        );
                        srv.status = WorkerServiceStatus::Failed;
                        return Err((idx, srv.factory_idx));
                    }
                }
            }
        }

        Ok(ready)
    }
}

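/// State machine driven by `ServerWorker`'s `Future` implementation.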
enum WorkerState {
    Available,
    Unavailable,
    Restarting(Restart),
    Shutdown(Shutdown),
}

struct Restart {
    factory_id: usize,
    token: usize,
    fut: LocalBoxFuture<'static, Result<(usize, BoxedServerService), ()>>,
}

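/// State of a graceful shutdown in progress.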
struct Shutdown {
    /// Timer for periodically re-checking the number of live connections.
    timer: Pin<Box<Sleep>>,

    /// Time the shutdown began, used to enforce `shutdown_timeout`.
    start_from: Instant,

    /// Reports whether shutdown completed gracefully (`true`) or timed out (`false`).
    tx: oneshot::Sender<bool>,
}

impl Default for WorkerState {
    fn default() -> Self {
        Self::Unavailable
    }
}

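// Stop the Arbiter driving this worker (if there is one) once the worker future is dropped.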
impl Drop for ServerWorker {
    fn drop(&mut self) {
        Arbiter::try_current().as_ref().map(ArbiterHandle::stop);
    }
}

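// The worker is itself a future: each poll first handles any pending stop command, then drives
// the current state (readiness checks, restarts, graceful shutdown, or dispatching connections).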
impl Future for ServerWorker {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().get_mut();

        // react to a pending stop command before doing any other work
        if let Poll::Ready(Some(Stop { graceful, tx })) = this.stop_rx.poll_recv(cx) {
            let num = this.counter.total();
            if num == 0 {
                info!("shutting down idle worker");
                let _ = tx.send(true);
                return Poll::Ready(());
            } else if graceful {
                info!("graceful worker shutdown; finishing {} connections", num);
                this.shutdown(false);

                this.state = WorkerState::Shutdown(Shutdown {
                    timer: Box::pin(sleep(Duration::from_secs(1))),
                    start_from: Instant::now(),
                    tx,
                });
            } else {
                info!("force shutdown worker, closing {} connections", num);
                this.shutdown(true);

                let _ = tx.send(false);
                return Poll::Ready(());
            }
        }

        match this.state {
            WorkerState::Unavailable => match this.check_readiness(cx) {
                Ok(true) => {
                    this.state = WorkerState::Available;
                    self.poll(cx)
                }
                Ok(false) => Poll::Pending,
                Err((token, idx)) => {
                    this.restart_service(token, idx);
                    self.poll(cx)
                }
            },

            WorkerState::Restarting(ref mut restart) => {
                let factory_id = restart.factory_id;
                let token = restart.token;

                let (token_new, service) =
                    ready!(restart.fut.as_mut().poll(cx)).unwrap_or_else(|_| {
                        panic!(
                            "Can not restart {:?} service",
                            this.factories[factory_id].name(token)
                        )
                    });

                assert_eq!(token, token_new);

                trace!(
                    "service {:?} has been restarted",
                    this.factories[factory_id].name(token)
                );

                this.services[token].created(service);
                this.state = WorkerState::Unavailable;

                self.poll(cx)
            }

            WorkerState::Shutdown(ref mut shutdown) => {
                // drop any connections that arrive while shutting down
                while let Poll::Ready(Some(conn)) = this.conn_rx.poll_recv(cx) {
                    let guard = this.counter.guard();
                    drop((conn, guard));
                }

                // wait for the re-check timer before inspecting the connection count again
                ready!(shutdown.timer.as_mut().poll(cx));

                if this.counter.total() == 0 {
                    // graceful shutdown: all connections have finished
                    if let WorkerState::Shutdown(shutdown) = mem::take(&mut this.state) {
                        let _ = shutdown.tx.send(true);
                    }

                    Poll::Ready(())
                } else if shutdown.start_from.elapsed() >= this.shutdown_timeout {
                    // shutdown timeout reached while connections are still alive
                    if let WorkerState::Shutdown(shutdown) = mem::take(&mut this.state) {
                        let _ = shutdown.tx.send(false);
                    }

                    Poll::Ready(())
                } else {
                    // reset the timer and check again in one second
                    let time = Instant::now() + Duration::from_secs(1);
                    shutdown.timer.as_mut().reset(time);
                    shutdown.timer.as_mut().poll(cx)
                }
            }

            WorkerState::Available => loop {
                match this.check_readiness(cx) {
                    Ok(true) => {}
                    Ok(false) => {
                        trace!("worker is unavailable");
                        this.state = WorkerState::Unavailable;
                        return self.poll(cx);
                    }
                    Err((token, idx)) => {
                        this.restart_service(token, idx);
                        return self.poll(cx);
                    }
                }

                // dispatch the next incoming connection to its service
                match ready!(this.conn_rx.poll_recv(cx)) {
                    Some(msg) => {
                        let guard = this.counter.guard();
                        let _ = this.services[msg.token]
                            .service
                            .call((guard, msg.io))
                            .into_inner();
                    }
                    None => return Poll::Ready(()),
                };
            },
        }
    }
}

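/// Converts `(factory index, token, service)` triples into `WorkerService` entries, asserting
/// that tokens are contiguous and in order.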
fn wrap_worker_services(services: Vec<(usize, usize, BoxedServerService)>) -> Vec<WorkerService> {
    services
        .into_iter()
        .fold(Vec::new(), |mut services, (idx, token, service)| {
            assert_eq!(token, services.len());
            services.push(WorkerService {
                factory_idx: idx,
                service,
                status: WorkerServiceStatus::Unavailable,
            });
            services
        })
}