diff --git a/Cargo.toml b/Cargo.toml index 106dcdc..9731bef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,10 @@ edition = "2021" clap = { version = "4.5.32", features = ["derive", "env"] } color-eyre = "0.6.3" figment = { version = "0.10.19", features = ["toml"] } -sea-schema = "0.16.1" +sea-orm-codegen = "1.1.8" +sea-schema = { version = "0.16.1", features = ["sqlx-all"] } serde = { version = "1.0.219", features = ["derive"] } +sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] } syn = { version = "2.0.100", features = ["extra-traits", "full"] } tokio = { version = "1.44.1", features = ["full"] } tracing = "0.1.41" diff --git a/src/generate.rs b/src/generate.rs index 3c6c621..57b9abc 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -1,16 +1,37 @@ +use core::time; + use color_eyre::{ - eyre::{eyre, ContextCompat}, + eyre::{eyre, ContextCompat, Report}, Result, }; +use sea_schema::sea_query::TableCreateStatement; use url::Url; -pub async fn generate(database_url: String) -> Result<()> { + +use crate::Config; +pub async fn get_tables( + database_url: String, + config: Config, +) -> Result<(Option, Vec)> { let url = Url::parse(&database_url)?; tracing::trace!(?url); let is_sqlite = url.scheme() == "sqlite"; + let filter_tables = |table: &String| -> bool { + config.sea_orm.table.only.is_empty() || config.sea_orm.table.only.contains(table) + }; - let database_name = (if !is_sqlite { + let filter_hidden_tables = |table: &str| -> bool { + if false { + true + } else { + !table.starts_with('_') + } + }; + + let filter_skip_tables = + |table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) }; + let database_name: &str = (if !is_sqlite { let database_name = url .path_segments() .context("No database name as part of path")? @@ -20,10 +41,129 @@ pub async fn generate(database_url: String) -> Result<()> { if database_name.is_empty() { return Err(eyre!("Database path name is empty")); } - Ok(database_name) + Ok::<&str, Report>(database_name) } else { Ok(Default::default()) })?; - Ok(()) + let (schema_name, table_stmts) = match url.scheme() { + "mysql" => { + use sea_schema::mysql::discovery::SchemaDiscovery; + use sqlx::MySql; + + tracing::info!("Connecting to MySQL ..."); + let connection = sqlx_connect::( + config.sea_orm.max_connections, + config.sea_orm.acquire_timeout, + url.as_str(), + None, + ) + .await?; + + tracing::info!("Discovering schema ..."); + let schema_discovery = SchemaDiscovery::new(connection, database_name); + let schema = schema_discovery.discover().await?; + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter_hidden_tables(&schema.info.name)) + .filter(|schema| filter_skip_tables(&schema.info.name)) + .map(|schema| schema.write()) + .collect(); + (None, table_stmts) + } + "sqlite" => { + use sea_schema::sqlite::discovery::SchemaDiscovery; + use sqlx::Sqlite; + + tracing::info!("Connecting to SQLite ..."); + let connection = sqlx_connect::( + config.sea_orm.max_connections, + config.sea_orm.acquire_timeout, + url.as_str(), + None, + ) + .await?; + + tracing::info!("Discovering schema ..."); + let schema_discovery = SchemaDiscovery::new(connection); + let schema = schema_discovery + .discover() + .await? + .merge_indexes_into_table(); + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter_tables(&schema.name)) + .filter(|schema| filter_hidden_tables(&schema.name)) + .filter(|schema| filter_skip_tables(&schema.name)) + .map(|schema| schema.write()) + .collect(); + (None, table_stmts) + } + "postgres" | "potgresql" => { + use sea_schema::postgres::discovery::SchemaDiscovery; + use sqlx::Postgres; + + tracing::info!("Connecting to Postgres ..."); + let schema = config + .sea_orm + .database_schema + .as_deref() + .unwrap_or("public"); + let connection = sqlx_connect::( + config.sea_orm.max_connections, + config.sea_orm.acquire_timeout, + url.as_str(), + Some(schema), + ) + .await?; + tracing::info!("Discovering schema ..."); + let schema_discovery = SchemaDiscovery::new(connection, schema); + let schema = schema_discovery.discover().await?; + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter_hidden_tables(&schema.info.name)) + .filter(|schema| filter_skip_tables(&schema.info.name)) + .map(|schema| schema.write()) + .collect(); + (config.sea_orm.database_schema, table_stmts) + } + _ => unimplemented!("{} is not supported", url.scheme()), + }; + tracing::info!("Schema discovered"); + + Ok((schema_name, table_stmts)) +} + +async fn sqlx_connect( + max_connections: u32, + acquire_timeout: u64, + url: &str, + schema: Option<&str>, +) -> Result> +where + DB: sqlx::Database, + for<'a> &'a mut ::Connection: sqlx::Executor<'a>, +{ + let mut pool_options = sqlx::pool::PoolOptions::::new() + .max_connections(max_connections) + .acquire_timeout(time::Duration::from_secs(acquire_timeout)); + // Set search_path for Postgres, E.g. Some("public") by default + // MySQL & SQLite connection initialize with schema `None` + if let Some(schema) = schema { + let sql = format!("SET search_path = '{schema}'"); + pool_options = pool_options.after_connect(move |conn, _| { + let sql = sql.clone(); + Box::pin(async move { + sqlx::Executor::execute(conn, sql.as_str()) + .await + .map(|_| ()) + }) + }); + } + pool_options.connect(url).await.map_err(Into::into) } diff --git a/src/main.rs b/src/main.rs index 1afd334..f257faf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,6 +22,9 @@ struct SeaOrmConfig { with_serde: bool, serde_skip_deserializing_primary_key: bool, serde_skip_hidden_column: bool, + max_connections: u32, + acquire_timeout: u64, + database_schema: Option, } #[derive(Deserialize, Serialize, Debug)] struct SeaOrmTableConfig { @@ -39,6 +42,9 @@ impl Default for Config { only: Vec::new(), exclude: Vec::new(), }, + database_schema: None, + max_connections: 10, + acquire_timeout: 30, expanded_format: false, include_hidden_tables: false, with_serde: false, @@ -67,6 +73,6 @@ async fn main() -> Result<()> { .extract()?; tracing::info!(?config); tracing::info!(?args); - generate::generate(args.database_url).await; + generate::get_tables(args.database_url, config).await?; Ok(()) }