diesel_async/stmt_cache.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3
4use diesel::backend::Backend;
5use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey};
6use diesel::connection::Instrumentation;
7use diesel::connection::InstrumentationEvent;
8use diesel::QueryResult;
9use futures_util::{future, FutureExt};
10
11#[derive(Default)]
12pub struct StmtCache<DB: Backend, S> {
13    cache: HashMap<StatementCacheKey<DB>, S>,
14}
15
16type PrepareFuture<'a, F, S> = future::Either<
17    future::Ready<QueryResult<(MaybeCached<'a, S>, F)>>,
18    future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>,
19>;
20
21#[async_trait::async_trait]
22pub trait PrepareCallback<S, M>: Sized {
23    async fn prepare(
24        self,
25        sql: &str,
26        metadata: &[M],
27        is_for_cache: PrepareForCache,
28    ) -> QueryResult<(S, Self)>;
29}
30
31impl<S, DB: Backend> StmtCache<DB, S> {
32    pub fn new() -> Self {
33        Self {
34            cache: HashMap::new(),
35        }
36    }
37
38    pub fn cached_prepared_statement<'a, F>(
39        &'a mut self,
40        cache_key: StatementCacheKey<DB>,
41        sql: String,
42        is_query_safe_to_cache: bool,
43        metadata: &[DB::TypeMetadata],
44        prepare_fn: F,
45        instrumentation: &std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
46    ) -> PrepareFuture<'a, F, S>
47    where
48        S: Send,
49        DB::QueryBuilder: Default,
50        DB::TypeMetadata: Clone + Send + Sync,
51        F: PrepareCallback<S, DB::TypeMetadata> + Send + 'a,
52        StatementCacheKey<DB>: Hash + Eq,
53    {
54        use std::collections::hash_map::Entry::{Occupied, Vacant};
55
56        if !is_query_safe_to_cache {
57            let metadata = metadata.to_vec();
58            let f = async move {
59                let stmt = prepare_fn
60                    .prepare(&sql, &metadata, PrepareForCache::No)
61                    .await?;
62                Ok((MaybeCached::CannotCache(stmt.0), stmt.1))
63            }
64            .boxed();
65            return future::Either::Right(f);
66        }
67
68        match self.cache.entry(cache_key) {
69            Occupied(entry) => future::Either::Left(future::ready(Ok((
70                MaybeCached::Cached(entry.into_mut()),
71                prepare_fn,
72            )))),
73            Vacant(entry) => {
74                let metadata = metadata.to_vec();
75                instrumentation
76                    .lock()
77                    .unwrap_or_else(|p| p.into_inner())
78                    .on_connection_event(InstrumentationEvent::cache_query(&sql));
79                let f = async move {
80                    let statement = prepare_fn
81                        .prepare(&sql, &metadata, PrepareForCache::Yes)
82                        .await?;
83
84                    Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1))
85                }
86                .boxed();
87                future::Either::Right(f)
88            }
89        }
90    }
91}