1
use serde::Deserialize;
2

            
3
#[cfg(feature = "mssql")]
4
use async_std::net::TcpStream;
5
#[cfg(feature = "mysql")]
6
use mysql_async::Pool;
7
#[cfg(feature = "mssql")]
8
use tiberius::{AuthMethod, Config};
9
#[cfg(feature = "postgres")]
10
use tokio_postgres::{Client, NoTls};
11

            
12
use crate::datasources::{Auth, DatasourceConfig};
13

            
14
/// Represents the current supported databases by Canyon
15
#[derive(Deserialize, Debug, Eq, PartialEq, Clone, Copy)]
16
pub enum DatabaseType {
17
    #[serde(alias = "postgres", alias = "postgresql")]
18
    #[cfg(feature = "postgres")]
19
    PostgreSql,
20
    #[serde(alias = "sqlserver", alias = "mssql")]
21
    #[cfg(feature = "mssql")]
22
    SqlServer,
23
    #[serde(alias = "mysql")]
24
    #[cfg(feature = "mysql")]
25
    MySQL,
26
}
27

            
28
/// Derives the database engine from the kind of authentication configured.
impl From<&Auth> for DatabaseType {
    /// Maps each `Auth` variant to the `DatabaseType` of the same engine.
    ///
    /// Every arm carries the same feature gate as the `DatabaseType` variant
    /// it produces; without the gate, a build that disables any backend
    /// would reference a variant that does not exist and fail to compile.
    fn from(value: &Auth) -> Self {
        // Matching on the dereferenced place (instead of the reference) keeps
        // the match well-formed even when every arm is compiled out.
        match *value {
            #[cfg(feature = "postgres")]
            Auth::Postgres(_) => DatabaseType::PostgreSql,
            #[cfg(feature = "mssql")]
            Auth::SqlServer(_) => DatabaseType::SqlServer,
            #[cfg(feature = "mysql")]
            Auth::MySQL(_) => DatabaseType::MySQL,
        }
    }
}
37

            
38
/// Handle to an open `PostgreSQL` database session.
#[cfg(feature = "postgres")]
pub struct PostgreSqlConnection {
    /// The `tokio_postgres` client used to talk to the server.
    pub client: Client,
    // TODO: decide whether the paired `Connection` object should be held here too.
}
44

            
45
/// Handle to an open `SqlServer` database session.
#[cfg(feature = "mssql")]
pub struct SqlServerConnection {
    /// Exclusive, `'static` reference to the `tiberius` client that runs over
    /// an async-std `TcpStream`.
    pub client: &'static mut tiberius::Client<TcpStream>,
}
50

            
51
/// Handle to a `Mysql` database.
#[cfg(feature = "mysql")]
pub struct MysqlConnection {
    /// The `mysql_async` pool through which connections are obtained.
    pub client: Pool,
}
56

            
57
/// Canyon's database connection handler.
///
/// At program start Canyon reads the datasources the user declared and
/// creates one connection per datasource; every such connection is stored as
/// one value of this enum, tagged with the engine it belongs to.
pub enum DatabaseConnection {
    /// A live PostgreSQL connection.
    #[cfg(feature = "postgres")]
    Postgres(PostgreSqlConnection),
    /// A live SQL Server connection.
    #[cfg(feature = "mssql")]
    SqlServer(SqlServerConnection),
    /// A live MySQL connection.
    #[cfg(feature = "mysql")]
    MySQL(MysqlConnection),
}
69

            
70
unsafe impl Send for DatabaseConnection {}
71
unsafe impl Sync for DatabaseConnection {}
72

            
73
impl DatabaseConnection {
74
168
    pub async fn new(
75
168
        datasource: &DatasourceConfig,
76
958
    ) -> Result<DatabaseConnection, Box<(dyn std::error::Error + Send + Sync + 'static)>> {
77
168
        match datasource.get_db_type() {
78
            #[cfg(feature = "postgres")]
79
            DatabaseType::PostgreSql => {
80
56
                let (username, password) = match &datasource.auth {
81
56
                    crate::datasources::Auth::Postgres(postgres_auth) => match postgres_auth {
82
56
                        crate::datasources::PostgresAuth::Basic { username, password } => {
83
56
                            (username.as_str(), password.as_str())
84
                        }
85
                    },
86
                    #[cfg(feature = "mssql")]
87
                    crate::datasources::Auth::SqlServer(_) => {
88
                        panic!("Found SqlServer auth configuration for a PostgreSQL datasource")
89
                    }
90
                    #[cfg(feature = "mysql")]
91
                    crate::datasources::Auth::MySQL(_) => {
92
                        panic!("Found MySql auth configuration for a PostgreSQL datasource")
93
                    }
94
                };
95
224
                let (new_client, new_connection) = tokio_postgres::connect(
96
112
                    &format!(
97
                        "postgres://{user}:{pswd}@{host}:{port}/{db}",
98
                        user = username,
99
                        pswd = password,
100
                        host = datasource.properties.host,
101
56
                        port = datasource.properties.port.unwrap_or_default(),
102
                        db = datasource.properties.db_name
103
56
                    )[..],
104
                    NoTls,
105
                )
106
378
                .await?;
107

            
108
293
                tokio::spawn(async move {
109
237
                    if let Err(e) = new_connection.await {
110
                        eprintln!("An error occurred while trying to connect to the PostgreSQL database: {e}");
111
                    }
112
56
                });
113

            
114
56
                Ok(DatabaseConnection::Postgres(PostgreSqlConnection {
115
                    client: new_client,
116
                    // connection: new_connection,
117
                }))
118
56
            }
119
            #[cfg(feature = "mssql")]
120
            DatabaseType::SqlServer => {
121
56
                let mut config = Config::new();
122

            
123
56
                config.host(&datasource.properties.host);
124
56
                config.port(datasource.properties.port.unwrap_or_default());
125
56
                config.database(&datasource.properties.db_name);
126

            
127
                // Using SQL Server authentication.
128
112
                config.authentication(match &datasource.auth {
129
                    #[cfg(feature = "postgres")]
130
                    crate::datasources::Auth::Postgres(_) => {
131
                        panic!("Found PostgreSQL auth configuration for a SqlServer database")
132
                    }
133
56
                    crate::datasources::Auth::SqlServer(sql_server_auth) => match sql_server_auth {
134
56
                        crate::datasources::SqlServerAuth::Basic { username, password } => {
135
56
                            AuthMethod::sql_server(username, password)
136
                        }
137
                        crate::datasources::SqlServerAuth::Integrated => AuthMethod::Integrated,
138
                    },
139
                    #[cfg(feature = "mysql")]
140
                    crate::datasources::Auth::MySQL(_) => {
141
                        panic!("Found PostgreSQL auth configuration for a SqlServer database")
142
                    }
143
                });
144

            
145
                // on production, it is not a good idea to do this. We should upgrade
146
                // Canyon in future versions to allow the user take care about this
147
                // configuration
148
56
                config.trust_cert();
149

            
150
                // Taking the address from the configuration, using async-std's
151
                // TcpStream to connect to the server.
152
224
                let tcp = TcpStream::connect(config.get_addr())
153
244
                    .await
154
                    .expect("Error instantiating the SqlServer TCP Stream");
155

            
156
                // We'll disable the Nagle algorithm. Buffering is handled
157
                // internally with a `Sink`.
158
56
                tcp.set_nodelay(true)
159
                    .expect("Error in the SqlServer `nodelay` config");
160

            
161
                // Handling TLS, login and other details related to the SQL Server.
162
280
                let client = tiberius::Client::connect(config, tcp).await;
163

            
164
56
                Ok(DatabaseConnection::SqlServer(SqlServerConnection {
165
56
                    client: Box::leak(Box::new(
166
56
                        client.expect("A failure happened connecting to the database"),
167
                    )),
168
                }))
169
56
            }
170
            #[cfg(feature = "mysql")]
171
            DatabaseType::MySQL => {
172
56
                let (user, password) = match &datasource.auth {
173
                    #[cfg(feature = "mssql")]
174
                    crate::datasources::Auth::SqlServer(_) => {
175
                        panic!("Found SqlServer auth configuration for a PostgreSQL datasource")
176
                    }
177
                    #[cfg(feature = "postgres")]
178
                    crate::datasources::Auth::Postgres(_) => {
179
                        panic!("Found MySql auth configuration for a PostgreSQL datasource")
180
                    }
181
                    #[cfg(feature = "mysql")]
182
56
                    crate::datasources::Auth::MySQL(mysql_auth) => match mysql_auth {
183
56
                        crate::datasources::MySQLAuth::Basic { username, password } => {
184
56
                            (username, password)
185
                        }
186
                    },
187
                };
188

            
189
                //TODO add options to optionals params in url
190

            
191
112
                let url = format!(
192
                    "mysql://{}:{}@{}:{}/{}",
193
                    user,
194
                    password,
195
                    datasource.properties.host,
196
56
                    datasource.properties.port.unwrap_or_default(),
197
                    datasource.properties.db_name
198
                );
199
56
                let mysql_connection = Pool::from_url(url)?;
200

            
201
56
                Ok(DatabaseConnection::MySQL(MysqlConnection {
202
                    client: { mysql_connection },
203
                }))
204
56
            }
205
        }
206
336
    }
207

            
208
    #[cfg(feature = "postgres")]
209
32
    pub fn postgres_connection(&self) -> &PostgreSqlConnection {
210
32
        match self {
211
32
            DatabaseConnection::Postgres(conn) => conn,
212
            #[cfg(all(feature = "postgres", feature = "mssql", feature = "mysql"))]
213
            _ => panic!(),
214
        }
215
32
    }
216

            
217
    #[cfg(feature = "mssql")]
218
31
    pub fn sqlserver_connection(&mut self) -> &mut SqlServerConnection {
219
31
        match self {
220
31
            DatabaseConnection::SqlServer(conn) => conn,
221
            #[cfg(all(feature = "postgres", feature = "mssql", feature = "mysql"))]
222
            _ => panic!(),
223
        }
224
31
    }
225

            
226
    #[cfg(feature = "mysql")]
227
29
    pub fn mysql_connection(&self) -> &MysqlConnection {
228
29
        match self {
229
29
            DatabaseConnection::MySQL(conn) => conn,
230
            #[cfg(all(feature = "postgres", feature = "mssql", feature = "mysql"))]
231
            _ => panic!(),
232
        }
233
29
    }
234
}
235

            
236
#[cfg(test)]
mod database_connection_handler {
    use super::*;
    use crate::CanyonSqlConfig;

    /// Deserializes a TOML document into a `CanyonSqlConfig`, panicking with
    /// the message shared by every scenario when parsing fails.
    fn parse_config(raw: &str) -> CanyonSqlConfig {
        toml::from_str(raw).expect("A failure happened retrieving the [canyon_sql] section")
    }

    /// Checks that `get_db_type()` reports the engine matching the `auth`
    /// section of each datasource, for every enabled feature combination.
    #[test]
    fn check_from_datasource() {
        // All three engines declared side by side.
        #[cfg(all(feature = "postgres", feature = "mssql", feature = "mysql"))]
        {
            const CONFIG_FILE_MOCK_ALT_ALL: &str = r#"
                [canyon_sql]
                datasources = [
                    {name = 'PostgresDS', auth = { postgresql = { basic = { username = "postgres", password = "postgres" } } }, properties.host = 'localhost', properties.db_name = 'triforce', properties.migrations='enabled' },
                    {name = 'SqlServerDS', auth = { sqlserver = { basic = { username = "sa", password = "SqlServer-10" } } }, properties.host = '192.168.0.250.1', properties.port = 3340, properties.db_name = 'triforce2', properties.migrations='disabled' },
                    {name = 'MysqlDS', auth = { mysql = { basic = { username = "root", password = "root" } } }, properties.host = '192.168.0.250.1', properties.port = 3340, properties.db_name = 'triforce2', properties.migrations='disabled' }
                ]
            "#;
            let config = parse_config(CONFIG_FILE_MOCK_ALT_ALL);
            let expected = [
                DatabaseType::PostgreSql,
                DatabaseType::SqlServer,
                DatabaseType::MySQL,
            ];
            for (position, db_type) in expected.iter().enumerate() {
                assert_eq!(
                    config.canyon_sql.datasources[position].get_db_type(),
                    *db_type
                );
            }
        }

        // PostgreSQL on its own.
        #[cfg(feature = "postgres")]
        {
            const CONFIG_FILE_MOCK_ALT_PG: &str = r#"
                [canyon_sql]
                datasources = [
                    {name = 'PostgresDS', auth = { postgresql = { basic = { username = "postgres", password = "postgres" } } }, properties.host = 'localhost', properties.db_name = 'triforce', properties.migrations='enabled' },
                ]
            "#;
            let config = parse_config(CONFIG_FILE_MOCK_ALT_PG);
            let detected = config.canyon_sql.datasources[0].get_db_type();
            assert_eq!(detected, DatabaseType::PostgreSql);
        }

        // SQL Server on its own.
        #[cfg(feature = "mssql")]
        {
            const CONFIG_FILE_MOCK_ALT_MSSQL: &str = r#"
                [canyon_sql]
                datasources = [
                    {name = 'SqlServerDS', auth = { sqlserver = { basic = { username = "sa", password = "SqlServer-10" } } }, properties.host = '192.168.0.250.1', properties.port = 3340, properties.db_name = 'triforce2', properties.migrations='disabled' }
                ]
            "#;
            let config = parse_config(CONFIG_FILE_MOCK_ALT_MSSQL);
            let detected = config.canyon_sql.datasources[0].get_db_type();
            assert_eq!(detected, DatabaseType::SqlServer);
        }

        // MySQL on its own.
        #[cfg(feature = "mysql")]
        {
            const CONFIG_FILE_MOCK_ALT_MYSQL: &str = r#"
                [canyon_sql]
                datasources = [
                    {name = 'MysqlDS', auth = { mysql = { basic = { username = "root", password = "root" } } }, properties.host = '192.168.0.250.1', properties.port = 3340, properties.db_name = 'triforce2', properties.migrations='disabled' }
                ]
            "#;
            let config = parse_config(CONFIG_FILE_MOCK_ALT_MYSQL);
            let detected = config.canyon_sql.datasources[0].get_db_type();
            assert_eq!(detected, DatabaseType::MySQL);
        }
    }
}