1
use async_trait::async_trait;
2
use std::fmt::Display;
3

            
4
use canyon_connection::canyon_database_connector::DatabaseConnection;
5
use canyon_connection::{get_database_connection, CACHED_DATABASE_CONN};
6

            
7
use crate::bounds::QueryParameter;
8
use crate::mapper::RowMapper;
9
use crate::query_elements::query_builder::{
10
    DeleteQueryBuilder, SelectQueryBuilder, UpdateQueryBuilder,
11
};
12
use crate::rows::CanyonRows;
13

            
14
/// Regex matching Canyon's positional placeholders (`$1`, `$2`, ...).
///
/// The MySQL launcher uses it to rewrite those placeholders into the `?` markers expected by
/// MySQL, and [`reorder_params`] uses it to line the bound parameters up with them.
#[cfg(feature = "mysql")]
pub const DETECT_PARAMS_IN_QUERY: &str = r"\$([\d])+";

/// Regex matching every double quote (`"`) and backslash (`\`) character.
///
/// The MySQL launcher removes all of its matches from a statement before sending it.
#[cfg(feature = "mysql")]
pub const DETECT_QUOTE_IN_QUERY: &str = r#"\"|\\"#;

/// This traits defines and implements a query against a database given
20
/// an statement `stmt` and the params to pass the to the client.
21
///
22
/// Returns [`std::result::Result`] of [`CanyonRows`], which is the core Canyon type to wrap
23
/// the result of the query provide automatic mappings and deserialization
24
#[async_trait]
25
pub trait Transaction<T> {
26
    /// Performs a query against the targeted database by the selected or
27
    /// the defaulted datasource, wrapping the resultant collection of entities
28
    /// in [`super::rows::CanyonRows`]
29
92
    async fn query<'a, S, Z>(
30
92
        stmt: S,
31
92
        params: Z,
32
        datasource_name: &'a str,
33
    ) -> Result<CanyonRows<T>, Box<(dyn std::error::Error + Sync + Send + 'static)>>
34
    where
35
        S: AsRef<str> + Display + Sync + Send + 'a,
36
        Z: AsRef<[&'a dyn QueryParameter<'a>]> + Sync + Send + 'a,
37
532
    {
38
92
        let mut guarded_cache = CACHED_DATABASE_CONN.lock().await;
39
92
        let database_conn = get_database_connection(datasource_name, &mut guarded_cache);
40

            
41
92
        match *database_conn {
42
            #[cfg(feature = "postgres")]
43
            DatabaseConnection::Postgres(_) => {
44
96
                postgres_query_launcher::launch::<T>(
45
                    database_conn,
46
32
                    stmt.to_string(),
47
32
                    params.as_ref(),
48
                )
49
129
                .await
50
            }
51
            #[cfg(feature = "mssql")]
52
            DatabaseConnection::SqlServer(_) => {
53
93
                sqlserver_query_launcher::launch::<T, Z>(
54
                    database_conn,
55
31
                    &mut stmt.to_string(),
56
31
                    params,
57
                )
58
93
                .await
59
31
            }
60
            #[cfg(feature = "mysql")]
61
            DatabaseConnection::MySQL(_) => {
62
87
                mysql_query_launcher::launch::<T>(database_conn, stmt.to_string(), params.as_ref())
63
218
                    .await
64
            }
65
        }
66
368
    }
67
}
68

            
69
/// *CrudOperations* it's the core part of Canyon-SQL.
70
///
71
/// Here it's defined and implemented every CRUD operation
72
/// that the user has available, just by deriving the `CanyonCrud`
73
/// derive macro when a struct contains the annotation.
74
///
75
/// Also, this traits needs that the type T over what it's generified
76
/// to implement certain types in order to work correctly.
77
///
78
/// The most notorious one it's the [`RowMapper<T>`] one, which allows
79
/// Canyon to directly maps database results into structs.
80
///
81
/// See it's definition and docs to see the implementations.
82
/// Also, you can find the written macro-code that performs the auto-mapping
83
/// in the *canyon_sql_root::canyon_macros* crates, on the root of this project.
84
#[async_trait]
85
pub trait CrudOperations<T>: Transaction<T>
86
where
87
    T: CrudOperations<T> + RowMapper<T>,
88
{
89
    async fn find_all<'a>() -> Result<Vec<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
90

            
91
    async fn find_all_datasource<'a>(
92
        datasource_name: &'a str,
93
    ) -> Result<Vec<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
94

            
95
    async fn find_all_unchecked<'a>() -> Vec<T>;
96

            
97
    async fn find_all_unchecked_datasource<'a>(datasource_name: &'a str) -> Vec<T>;
98

            
99
    fn select_query<'a>() -> SelectQueryBuilder<'a, T>;
100

            
101
    fn select_query_datasource(datasource_name: &str) -> SelectQueryBuilder<'_, T>;
102

            
103
    async fn count() -> Result<i64, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
104

            
105
    async fn count_datasource<'a>(
106
        datasource_name: &'a str,
107
    ) -> Result<i64, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
108

            
109
    async fn find_by_pk<'a>(
110
        value: &'a dyn QueryParameter<'a>,
111
    ) -> Result<Option<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
112

            
113
    async fn find_by_pk_datasource<'a>(
114
        value: &'a dyn QueryParameter<'a>,
115
        datasource_name: &'a str,
116
    ) -> Result<Option<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>>;
117

            
118
    async fn insert<'a>(&mut self) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
119

            
120
    async fn insert_datasource<'a>(
121
        &mut self,
122
        datasource_name: &'a str,
123
    ) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
124

            
125
    async fn multi_insert<'a>(
126
        instances: &'a mut [&'a mut T],
127
    ) -> Result<(), Box<(dyn std::error::Error + Send + Sync + 'static)>>;
128

            
129
    async fn multi_insert_datasource<'a>(
130
        instances: &'a mut [&'a mut T],
131
        datasource_name: &'a str,
132
    ) -> Result<(), Box<(dyn std::error::Error + Send + Sync + 'static)>>;
133

            
134
    async fn update(&self) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
135

            
136
    async fn update_datasource<'a>(
137
        &self,
138
        datasource_name: &'a str,
139
    ) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
140

            
141
    fn update_query<'a>() -> UpdateQueryBuilder<'a, T>;
142

            
143
    fn update_query_datasource(datasource_name: &str) -> UpdateQueryBuilder<'_, T>;
144

            
145
    async fn delete(&self) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
146

            
147
    async fn delete_datasource<'a>(
148
        &self,
149
        datasource_name: &'a str,
150
    ) -> Result<(), Box<dyn std::error::Error + Sync + Send>>;
151

            
152
    fn delete_query<'a>() -> DeleteQueryBuilder<'a, T>;
153

            
154
    fn delete_query_datasource(datasource_name: &str) -> DeleteQueryBuilder<'_, T>;
155
}
156

            
157
#[cfg(feature = "postgres")]
mod postgres_query_launcher {
    use canyon_connection::canyon_database_connector::DatabaseConnection;

    use crate::bounds::QueryParameter;
    use crate::rows::CanyonRows;

    /// Sends `stmt` to PostgreSQL, binding `params` in the given order.
    ///
    /// Each parameter is converted through `as_postgres_param`, and the rows returned by the
    /// client are wrapped in [`CanyonRows::Postgres`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the PostgreSQL client when the query fails.
    pub async fn launch<'a, T>(
        db_conn: &DatabaseConnection,
        stmt: String,
        params: &'a [&'_ dyn QueryParameter<'_>],
    ) -> Result<CanyonRows<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>> {
        let pg_params: Vec<_> = params
            .iter()
            .map(|param| param.as_postgres_param())
            .collect();

        let rows = db_conn
            .postgres_connection()
            .client
            .query(&stmt, pg_params.as_slice())
            .await?;

        Ok(CanyonRows::Postgres(rows))
    }
}

#[cfg(feature = "mssql")]
mod sqlserver_query_launcher {
    use crate::rows::CanyonRows;
    use crate::{
        bounds::QueryParameter,
        canyon_connection::{canyon_database_connector::DatabaseConnection, tiberius::Query},
    };

    /// Runs `stmt` against SQL Server, binding `params` in the given order.
    ///
    /// Canyon statements use `$N` positional placeholders and, for inserts, a trailing
    /// `RETURNING <columns>` clause. Before execution the statement is adapted to SQL Server:
    /// the `RETURNING` clause is turned into an `OUTPUT inserted.<column>` clause placed before
    /// `VALUES` (written back into `stmt`), and every `$` is replaced by `@P`, so `$1` becomes
    /// `@P1`. All result sets are flattened into a single [`CanyonRows::Tiberius`].
    ///
    /// # Errors
    ///
    /// Returns an error when a statement containing `RETURNING` has no `VALUES` before it (so
    /// the `OUTPUT` rewrite cannot be applied), or when the client fails to run the query or
    /// to read its results.
    pub async fn launch<'a, T, Z>(
        db_conn: &mut DatabaseConnection,
        stmt: &mut String,
        params: Z,
    ) -> Result<CanyonRows<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>>
    where
        Z: AsRef<[&'a dyn QueryParameter<'a>]> + Sync + Send + 'a,
    {
        // Re-generate the insert statement in SQL Server syntax, so the PK value(s) are
        // retrieved after the insert.
        if stmt.contains("RETURNING") {
            match rewrite_returning_clause(stmt.as_str()) {
                Some(rewritten) => *stmt = rewritten,
                None => {
                    return Err(format!(
                        "Unable to adapt the RETURNING clause to SQL Server: no VALUES found in `{}`",
                        stmt
                    )
                    .into());
                }
            }
        }

        let mut mssql_query = Query::new(stmt.replace('$', "@P"));
        for param in params.as_ref() {
            mssql_query.bind(*param);
        }

        let result_sets = mssql_query
            .query(db_conn.sqlserver_connection().client)
            .await?
            .into_results()
            .await?;

        Ok(CanyonRows::Tiberius(
            result_sets.into_iter().flatten().collect(),
        ))
    }

    /// Rewrites `INSERT ... VALUES (...) RETURNING a, b` into
    /// `INSERT ... OUTPUT inserted.a, inserted.b VALUES (...)`.
    ///
    /// Each returned column is prefixed with `inserted.` individually, because SQL Server
    /// requires every column in an `OUTPUT` clause to be qualified. Returns `None` when
    /// `stmt` has no `RETURNING` keyword, or no `VALUES` keyword before it.
    fn rewrite_returning_clause(stmt: &str) -> Option<String> {
        let (insert_part, returning_columns) = stmt.split_once("RETURNING")?;
        let (target, values) = insert_part.split_once("VALUES")?;

        let output_columns = returning_columns
            .split(',')
            .map(|column| format!("inserted.{}", column.trim()))
            .collect::<Vec<_>>()
            .join(", ");

        Some(format!(
            "{} OUTPUT {} VALUES {}",
            target.trim(),
            output_columns,
            values.trim()
        ))
    }
}

#[cfg(feature = "mysql")]
mod mysql_query_launcher {
    use std::sync::Arc;

    use mysql_async::prelude::Query;
    use mysql_async::QueryWithParams;
    use mysql_async::Value;

    use canyon_connection::canyon_database_connector::DatabaseConnection;

    use crate::bounds::QueryParameter;
    use crate::rows::CanyonRows;
    use mysql_async::Row;
    use mysql_common::constants::ColumnType;
    use mysql_common::row;

    use super::reorder_params;
    use crate::crud::{DETECT_PARAMS_IN_QUERY, DETECT_QUOTE_IN_QUERY};
    use regex::Regex;

    /// Runs `stmt` with `params` against MySQL.
    ///
    /// The statement is adapted first: every `$N` placeholder becomes `?`, every double quote
    /// and backslash is stripped, and everything from ` RETURNING` onwards is cut off. The
    /// parameters are then laid out in placeholder order (see [`reorder_params`]).
    ///
    /// When a ` RETURNING` clause was present, the result is a single row holding the id
    /// reported by `last_insert_id`. Otherwise the rows produced by the query are returned.
    ///
    /// # Errors
    ///
    /// Returns an error, instead of panicking, if any of these fail:
    ///
    /// - obtaining a connection;
    /// - compiling a pattern;
    /// - executing the query;
    /// - collecting its rows;
    /// - reading the insert id for a statement that had a `RETURNING` clause (the server
    ///   reported none).
    pub async fn launch<'a, T>(
        db_conn: &DatabaseConnection,
        stmt: String,
        params: &'a [&'_ dyn QueryParameter<'_>],
    ) -> Result<CanyonRows<T>, Box<(dyn std::error::Error + Send + Sync + 'static)>> {
        let mysql_connection = db_conn.mysql_connection().client.get_conn().await?;

        // `$N` -> `?`. No `regex::escape` pre-pass is needed: any backslash it would insert is
        // removed again by the quote-stripping pattern below.
        let query_string = Regex::new(DETECT_PARAMS_IN_QUERY)?.replace_all(&stmt, "?");
        let mut query_string = Regex::new(DETECT_QUOTE_IN_QUERY)?
            .replace_all(&query_string, "")
            .to_string();

        // MySQL has no RETURNING clause: drop it and report the generated id afterwards.
        let is_insert = match query_string.find(" RETURNING") {
            Some(index_start_clausule_returning) => {
                query_string.truncate(index_start_clausule_returning);
                true
            }
            None => false,
        };

        let params_query: Vec<Value> =
            reorder_params(&stmt, params, |f| f.as_mysql_param().to_value());

        let query_with_params = QueryWithParams {
            query: query_string,
            params: params_query,
        };

        let mut query_result = query_with_params.run(mysql_connection).await?;

        let result_rows = if is_insert {
            let last_insert = query_result
                .last_insert_id()
                .map(Value::UInt)
                .ok_or("Error getting pk id in insert")?;

            vec![row::new_row(
                vec![last_insert],
                Arc::new([mysql_async::Column::new(ColumnType::MYSQL_TYPE_UNKNOWN)]),
            )]
        } else {
            query_result.collect::<Row>().await?
        };

        Ok(CanyonRows::MySQL(result_rows))
    }
}

/// Lists the parameters referenced by `stmt` in the order their placeholders appear,
/// converting each one with `fn_parser`.
///
/// Every `$N` placeholder (as matched by [`DETECT_PARAMS_IN_QUERY`]) selects `params[N - 1]`.
/// A parameter referenced several times is therefore emitted several times. The MySQL
/// launcher relies on this to feed its positional `?` markers.
///
/// # Panics
///
/// Panics if the pattern fails to compile, or if a placeholder has any of these problems:
///
/// - its index does not fit in `usize`;
/// - it is `$0`;
/// - it points past the end of `params`.
#[cfg(feature = "mysql")]
fn reorder_params<T>(
    stmt: &str,
    params: &[&'_ dyn QueryParameter<'_>],
    fn_parser: impl Fn(&&dyn QueryParameter<'_>) -> T,
) -> Vec<T> {
    let rg = regex::Regex::new(DETECT_PARAMS_IN_QUERY)
        .expect("Error create regex with detect params pattern expression");

    rg.find_iter(stmt)
        .map(|positional_param| {
            let pp: &str = positional_param.as_str();
            // Placeholders are 1-based. `checked_sub` keeps `$0` from underflowing, which
            // would otherwise wrap to `usize::MAX` in release builds and report a misleading
            // "element" error.
            let pp_index = pp[1..] // param $1 -> get 1
                .parse::<usize>()
                .expect("Error parse mapped parameter to usized.")
                .checked_sub(1)
                .expect("Positional parameters are 1-based: `$0` is not a valid parameter.");

            let element = params
                .get(pp_index)
                .expect("Error obtaining the element of the mapping against parameters.");
            fn_parser(element)
        })
        .collect()
}