diesel_async/
stmt_cache.rs1use std::collections::HashMap;
2use std::hash::Hash;
3
4use diesel::backend::Backend;
5use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey};
6use diesel::connection::Instrumentation;
7use diesel::connection::InstrumentationEvent;
8use diesel::QueryResult;
9use futures_util::{future, FutureExt};
10
11#[derive(Default)]
12pub struct StmtCache<DB: Backend, S> {
13 cache: HashMap<StatementCacheKey<DB>, S>,
14}
15
16type PrepareFuture<'a, F, S> = future::Either<
17 future::Ready<QueryResult<(MaybeCached<'a, S>, F)>>,
18 future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>,
19>;
20
21#[async_trait::async_trait]
22pub trait PrepareCallback<S, M>: Sized {
23 async fn prepare(
24 self,
25 sql: &str,
26 metadata: &[M],
27 is_for_cache: PrepareForCache,
28 ) -> QueryResult<(S, Self)>;
29}
30
31impl<S, DB: Backend> StmtCache<DB, S> {
32 pub fn new() -> Self {
33 Self {
34 cache: HashMap::new(),
35 }
36 }
37
38 pub fn cached_prepared_statement<'a, F>(
39 &'a mut self,
40 cache_key: StatementCacheKey<DB>,
41 sql: String,
42 is_query_safe_to_cache: bool,
43 metadata: &[DB::TypeMetadata],
44 prepare_fn: F,
45 instrumentation: &std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
46 ) -> PrepareFuture<'a, F, S>
47 where
48 S: Send,
49 DB::QueryBuilder: Default,
50 DB::TypeMetadata: Clone + Send + Sync,
51 F: PrepareCallback<S, DB::TypeMetadata> + Send + 'a,
52 StatementCacheKey<DB>: Hash + Eq,
53 {
54 use std::collections::hash_map::Entry::{Occupied, Vacant};
55
56 if !is_query_safe_to_cache {
57 let metadata = metadata.to_vec();
58 let f = async move {
59 let stmt = prepare_fn
60 .prepare(&sql, &metadata, PrepareForCache::No)
61 .await?;
62 Ok((MaybeCached::CannotCache(stmt.0), stmt.1))
63 }
64 .boxed();
65 return future::Either::Right(f);
66 }
67
68 match self.cache.entry(cache_key) {
69 Occupied(entry) => future::Either::Left(future::ready(Ok((
70 MaybeCached::Cached(entry.into_mut()),
71 prepare_fn,
72 )))),
73 Vacant(entry) => {
74 let metadata = metadata.to_vec();
75 instrumentation
76 .lock()
77 .unwrap_or_else(|p| p.into_inner())
78 .on_connection_event(InstrumentationEvent::cache_query(&sql));
79 let f = async move {
80 let statement = prepare_fn
81 .prepare(&sql, &metadata, PrepareForCache::Yes)
82 .await?;
83
84 Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1))
85 }
86 .boxed();
87 future::Either::Right(f)
88 }
89 }
90 }
91}