1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use std::collections::HashMap;
use std::hash::Hash;

use diesel::backend::Backend;
use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::QueryResult;
use futures_util::{future, FutureExt};

/// Cache of prepared statements of type `S`, keyed by the
/// [`StatementCacheKey`] of backend `DB`.
pub struct StmtCache<DB: Backend, S> {
    /// Prepared statements stored under the key they were prepared for.
    cache: HashMap<StatementCacheKey<DB>, S>,
}

// `Default` is written by hand instead of derived: `#[derive(Default)]` would
// add `DB: Default` and `S: Default` bounds, although building an empty map
// requires neither of them.
impl<DB: Backend, S> Default for StmtCache<DB, S> {
    /// Returns an empty cache.
    fn default() -> Self {
        Self {
            cache: HashMap::new(),
        }
    }
}

/// Future type returned by `StmtCache::cached_prepared_statement`.
///
/// * `Left` is an already-resolved future.
/// * `Right` is a boxed future that still has to be driven to completion.
///
/// Both sides resolve to a `QueryResult` holding the statement (wrapped in
/// [`MaybeCached`]) together with a value of type `F`.
type PrepareFuture<'a, F, S> = future::Either<
    future::Ready<QueryResult<(MaybeCached<'a, S>, F)>>,
    future::BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>,
>;

/// Asynchronous callback that turns SQL text into a prepared statement.
///
/// `S` is the statement type produced and `M` is the element type of the
/// metadata slice handed to [`PrepareCallback::prepare`].
#[async_trait::async_trait]
pub trait PrepareCallback<S, M>: Sized {
    /// Prepares `sql` using the given `metadata`.
    ///
    /// `is_for_cache` is passed through from the caller; presumably it tells
    /// the implementation whether the resulting statement will be kept in a
    /// statement cache (confirm against the implementors).
    ///
    /// The callback is consumed and, on success, handed back alongside the
    /// prepared statement.
    async fn prepare(
        self,
        sql: &str,
        metadata: &[M],
        is_for_cache: PrepareForCache,
    ) -> QueryResult<(S, Self)>;
}

impl<S, DB: Backend> StmtCache<DB, S> {
    /// Creates a cache that holds no statements yet.
    pub fn new() -> Self {
        let cache = HashMap::new();
        Self { cache }
    }

    /// Looks up, or prepares, the statement identified by `cache_key`.
    ///
    /// * If `is_query_safe_to_cache` is `false`, the cache is not consulted:
    ///   the statement is prepared with [`PrepareForCache::No`] and returned
    ///   as [`MaybeCached::CannotCache`].
    /// * If the key is already present, an immediately ready future yields the
    ///   stored statement as [`MaybeCached::Cached`] and returns `prepare_fn`
    ///   unused.
    /// * Otherwise a `cache_query` instrumentation event is emitted, the
    ///   statement is prepared with [`PrepareForCache::Yes`], inserted into
    ///   the cache and returned as [`MaybeCached::Cached`].
    ///
    /// # Errors
    ///
    /// The returned future resolves to the error produced by
    /// `prepare_fn.prepare(..)`; in that case nothing is inserted into the
    /// cache.
    pub fn cached_prepared_statement<'a, F>(
        &'a mut self,
        cache_key: StatementCacheKey<DB>,
        sql: String,
        is_query_safe_to_cache: bool,
        metadata: &[DB::TypeMetadata],
        prepare_fn: F,
        instrumentation: &std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
    ) -> PrepareFuture<'a, F, S>
    where
        S: Send,
        DB::QueryBuilder: Default,
        DB::TypeMetadata: Clone + Send + Sync,
        F: PrepareCallback<S, DB::TypeMetadata> + Send + 'a,
        StatementCacheKey<DB>: Hash + Eq,
    {
        use std::collections::hash_map::Entry;

        if is_query_safe_to_cache {
            match self.cache.entry(cache_key) {
                // Cache hit: nothing to prepare, hand the callback straight back.
                Entry::Occupied(hit) => {
                    let cached = MaybeCached::Cached(hit.into_mut());
                    future::Either::Left(future::ready(Ok((cached, prepare_fn))))
                }
                // Cache miss: prepare the statement and store it in the free slot.
                Entry::Vacant(slot) => {
                    // The future outlives the borrowed slice, so it needs its own copy.
                    let owned_metadata = metadata.to_vec();
                    {
                        // A poisoned lock is recovered rather than propagated, and the
                        // guard is released before the future is built.
                        let mut guard = instrumentation
                            .lock()
                            .unwrap_or_else(|poisoned| poisoned.into_inner());
                        guard.on_connection_event(InstrumentationEvent::cache_query(&sql));
                    }
                    let pending = async move {
                        let (statement, callback) = prepare_fn
                            .prepare(&sql, &owned_metadata, PrepareForCache::Yes)
                            .await?;

                        Ok((MaybeCached::Cached(slot.insert(statement)), callback))
                    }
                    .boxed();
                    future::Either::Right(pending)
                }
            }
        } else {
            // Not cacheable: always prepare a fresh statement and bypass the map.
            let owned_metadata = metadata.to_vec();
            let uncached = async move {
                let (statement, callback) = prepare_fn
                    .prepare(&sql, &owned_metadata, PrepareForCache::No)
                    .await?;
                Ok((MaybeCached::CannotCache(statement), callback))
            }
            .boxed();
            future::Either::Right(uncached)
        }
    }
}