1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use std::collections::HashMap;
use std::hash::Hash;

use diesel::backend::Backend;
use diesel::connection::statement_cache::{MaybeCached, PrepareForCache, StatementCacheKey};
use diesel::query_builder::{QueryFragment, QueryId};
use diesel::QueryResult;
use futures::future::BoxFuture;
use futures::FutureExt;

/// Cache of prepared statements, keyed by [`StatementCacheKey`]
/// (the statement's SQL plus its bind metadata).
///
/// `S` is the backend-specific prepared-statement handle type.
pub struct StmtCache<DB: Backend, S> {
    cache: HashMap<StatementCacheKey<DB>, S>,
}

// Manual `Default` impl instead of `#[derive(Default)]`: the derive would
// add spurious `DB: Default` and `S: Default` bounds even though an empty
// `HashMap` requires neither. Backend types generally do not implement
// `Default`, which would make the derived impl unusable in practice.
impl<DB: Backend, S> Default for StmtCache<DB, S> {
    fn default() -> Self {
        StmtCache {
            cache: HashMap::new(),
        }
    }
}

/// Future returned by [`StmtCache::cached_prepared_statement`].
///
/// `Either::Left` is the synchronous path (cache hit, or an error detected
/// before any prepare work starts); `Either::Right` is the boxed async path
/// that actually prepares the statement. Both resolve to the (possibly
/// cached) statement together with the `PrepareCallback` value `F`, which
/// is taken by value and handed back to the caller.
type PrepareFuture<'a, F, S> = futures::future::Either<
    futures::future::Ready<QueryResult<(MaybeCached<'a, S>, F)>>,
    BoxFuture<'a, QueryResult<(MaybeCached<'a, S>, F)>>,
>;

/// Callback used by [`StmtCache`] to prepare a statement on the underlying
/// connection.
///
/// `prepare` consumes `self` and returns it alongside the prepared
/// statement, so ownership of the callback (and any connection state it
/// carries) is handed back to the caller once the async prepare completes.
#[async_trait::async_trait]
pub trait PrepareCallback<S, M> {
    /// Prepares `sql` with the given bind `metadata`.
    ///
    /// `is_for_cache` signals whether the resulting statement will be
    /// stored in the statement cache, so implementations can prepare it
    /// differently (e.g. with a name) if the backend distinguishes these.
    async fn prepare(
        self,
        sql: &str,
        metadata: &[M],
        is_for_cache: PrepareForCache,
    ) -> QueryResult<(S, Self)>
    where
        Self: Sized;
}

impl<S, DB: Backend> StmtCache<DB, S> {
    pub fn new() -> Self {
        Self {
            cache: HashMap::new(),
        }
    }

    pub fn cached_prepared_statement<'a, T, F>(
        &'a mut self,
        query: T,
        metadata: &[DB::TypeMetadata],
        prepare_fn: F,
        backend: &DB,
    ) -> PrepareFuture<'a, F, S>
    where
        S: Send,
        DB::QueryBuilder: Default,
        DB::TypeMetadata: Clone + Send + Sync,
        T: QueryFragment<DB> + QueryId + Send,
        F: PrepareCallback<S, DB::TypeMetadata> + Send + 'a,
        StatementCacheKey<DB>: Hash + Eq,
    {
        use std::collections::hash_map::Entry::{Occupied, Vacant};

        let cache_key = match StatementCacheKey::for_source(&query, metadata, backend) {
            Ok(key) => key,
            Err(e) => return futures::future::Either::Left(futures::future::ready(Err(e))),
        };

        let is_query_safe_to_cache = match query.is_safe_to_cache_prepared(backend) {
            Ok(is_safe_to_cache) => is_safe_to_cache,
            Err(e) => return futures::future::Either::Left(futures::future::ready(Err(e))),
        };

        if !is_query_safe_to_cache {
            let sql = match cache_key.sql(&query, backend) {
                Ok(sql) => sql.into_owned(),
                Err(e) => return futures::future::Either::Left(futures::future::ready(Err(e))),
            };

            let metadata = metadata.to_vec();
            let f = async move {
                let stmt = prepare_fn
                    .prepare(&sql, &metadata, PrepareForCache::No)
                    .await?;
                Ok((MaybeCached::CannotCache(stmt.0), stmt.1))
            }
            .boxed();
            return futures::future::Either::Right(f);
        }

        match self.cache.entry(cache_key) {
            Occupied(entry) => futures::future::Either::Left(futures::future::ready(Ok((
                MaybeCached::Cached(entry.into_mut()),
                prepare_fn,
            )))),
            Vacant(entry) => {
                let sql = match entry.key().sql(&query, backend) {
                    Ok(sql) => sql.into_owned(),
                    Err(e) => return futures::future::Either::Left(futures::future::ready(Err(e))),
                };
                let metadata = metadata.to_vec();
                let f = async move {
                    let statement = prepare_fn
                        .prepare(&sql, &metadata, PrepareForCache::Yes)
                        .await?;

                    Ok((MaybeCached::Cached(entry.insert(statement.0)), statement.1))
                }
                .boxed();
                futures::future::Either::Right(f)
            }
        }
    }
}