diff --git a/Cargo.lock b/Cargo.lock index 6b158e5..ea2e1af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1998,6 +1998,8 @@ dependencies = [ "sha2", "smallvec", "thiserror", + "tokio", + "tokio-stream", "tracing", "url", ] @@ -2037,6 +2039,7 @@ dependencies = [ "sqlx-sqlite", "syn", "tempfile", + "tokio", "url", ] @@ -2290,6 +2293,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" diff --git a/Cargo.toml b/Cargo.toml index 9d68394..0137e53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ serde = { version = "1.0.219", features = ["derive"] } serde-inline-default = "0.2.3" serde_json = "1.0.140" serde_yaml = "0.9.34" -sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] } +sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite", "runtime-tokio"] } syn = { version = "2.0.100", features = ["extra-traits", "full"] } tokio = { version = "1.44.1", features = ["full"] } toml_edit = { version = "0.22.24", features = ["serde"] } diff --git a/generator.toml b/generator.toml index 304942a..6319488 100644 --- a/generator.toml +++ b/generator.toml @@ -1,6 +1,15 @@ # This file is used to configure the SeaORM generator. [modules.discovery] enable = true +[modules.discovery.filter] +include_hidden = false +skip_seaql_migrations = true [modules.sea_orm] enable = true +prelude = true +[modules.sea_orm.serde] +enable = true +skip_deserializing_primary_key = false +skip_hidden_column = false +[modules.sea_orm.entity] diff --git a/src/generator/mod.rs b/src/generator/mod.rs index e6d2c05..ec7918c 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -20,6 +20,7 @@ pub async fn generate( ctx.get_anymap_mut() .insert(DatabaseUrl(database_url.to_owned())); module_manager.validate().await?; + module_manager.execute().await?; // let db_filter = config.sea_orm.entity.tables.get_filter(); // let (table_stmts, db_type) = diff --git a/src/generator/modules/discovery/db.rs b/src/generator/modules/discovery/db.rs new file mode 100644 index 0000000..0e93e88 --- /dev/null +++ b/src/generator/modules/discovery/db.rs @@ -0,0 +1,157 @@ +use core::time; + +use color_eyre::eyre::{eyre, ContextCompat, Report, Result}; +use sea_schema::sea_query::TableCreateStatement; +use url::Url; + +use crate::config::db::DbConfig; + +use super::DiscoveryConfig; +#[derive(Debug, Clone)] +pub enum DbType { + MySql, + Postgres, + Sqlite, +} + +pub async fn get_tables( + database_url: String, + database_config: &DiscoveryConfig, +) -> Result<(Vec, DbType)> { + let url = Url::parse(&database_url)?; + + tracing::trace!(?url); + + let is_sqlite = url.scheme() == "sqlite"; + + let database_name: &str = (if !is_sqlite { + let database_name = url + .path_segments() + .context("No database name as part of path")? + .next() + .context("No database name as part of path")?; + + if database_name.is_empty() { + return Err(eyre!("Database path name is empty")); + } + Ok::<&str, Report>(database_name) + } else { + Ok(Default::default()) + })?; + + let filter = database_config.filter.clone().get_filter(); + + let (table_stmts, db_type) = match url.scheme() { + "mysql" => { + use sea_schema::mysql::discovery::SchemaDiscovery; + use sqlx::MySql; + + tracing::info!("Connecting to MySQL"); + let connection = sqlx_connect::( + database_config.max_connections, + database_config.acquire_timeout, + url.as_str(), + None, + ) + .await?; + + tracing::info!("Discovering schema"); + let schema_discovery = SchemaDiscovery::new(connection, database_name); + let schema = schema_discovery.discover().await?; + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter(&schema.info.name)) + .map(|schema| schema.write()) + .collect(); + (table_stmts, DbType::MySql) + } + "sqlite" => { + use sea_schema::sqlite::discovery::SchemaDiscovery; + use sqlx::Sqlite; + + tracing::info!("Connecting to SQLite"); + let connection = sqlx_connect::( + database_config.max_connections, + database_config.acquire_timeout, + url.as_str(), + None, + ) + .await?; + + tracing::info!("Discovering schema"); + let schema_discovery = SchemaDiscovery::new(connection); + let schema = schema_discovery + .discover() + .await? + .merge_indexes_into_table(); + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter(&schema.name)) + .map(|schema| schema.write()) + .collect(); + (table_stmts, DbType::Sqlite) + } + "postgres" | "potgresql" => { + use sea_schema::postgres::discovery::SchemaDiscovery; + use sqlx::Postgres; + + tracing::info!("Connecting to Postgres"); + let schema = &database_config + .database_schema + .as_deref() + .unwrap_or("public"); + let connection = sqlx_connect::( + database_config.max_connections, + database_config.acquire_timeout, + url.as_str(), + Some(schema), + ) + .await?; + tracing::info!("Discovering schema"); + let schema_discovery = SchemaDiscovery::new(connection, schema); + let schema = schema_discovery.discover().await?; + tracing::info!(?schema); + let table_stmts = schema + .tables + .into_iter() + .filter(|schema| filter(&schema.info.name)) + .map(|schema| schema.write()) + .collect(); + (table_stmts, DbType::Postgres) + } + _ => unimplemented!("{} is not supported", url.scheme()), + }; + tracing::info!("Schema discovered"); + + Ok((table_stmts, db_type)) +} +async fn sqlx_connect( + max_connections: u32, + acquire_timeout: u64, + url: &str, + schema: Option<&str>, +) -> Result> +where + DB: sqlx::Database, + for<'a> &'a mut ::Connection: sqlx::Executor<'a>, +{ + let mut pool_options = sqlx::pool::PoolOptions::::new() + .max_connections(max_connections) + .acquire_timeout(time::Duration::from_secs(acquire_timeout)); + // Set search_path for Postgres, E.g. Some("public") by default + // MySQL & SQLite connection initialize with schema `None` + if let Some(schema) = schema { + let sql = format!("SET search_path = '{schema}'"); + pool_options = pool_options.after_connect(move |conn, _| { + let sql = sql.clone(); + Box::pin(async move { + sqlx::Executor::execute(conn, sql.as_str()) + .await + .map(|_| ()) + }) + }); + } + pool_options.connect(url).await.map_err(Into::into) +} diff --git a/src/generator/modules/discovery/mod.rs b/src/generator/modules/discovery/mod.rs index aaf5028..08ab402 100644 --- a/src/generator/modules/discovery/mod.rs +++ b/src/generator/modules/discovery/mod.rs @@ -1,3 +1,4 @@ +pub mod db; use crate::generator::DatabaseUrl; use super::{Module, ModulesContext}; @@ -5,28 +6,97 @@ use color_eyre::Result; use serde::Deserialize; use serde_inline_default::serde_inline_default; -#[serde_inline_default] #[derive(Debug, Clone, Deserialize)] +#[serde(default)] pub struct DiscoveryConfig { - #[serde_inline_default(false)] pub enable: bool, - #[serde_inline_default(None)] pub database_schema: Option, - #[serde_inline_default(10)] pub max_connections: u32, - #[serde_inline_default(30)] - pub acquire_timeout: u32, + pub acquire_timeout: u64, + pub filter: DiscoveryFilterConfig, } +impl Default for DiscoveryConfig { + fn default() -> Self { + Self { + enable: false, + database_schema: None, + max_connections: 10, + acquire_timeout: 30, + filter: DiscoveryFilterConfig::default(), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct DiscoveryFilterConfig { + pub include_hidden: bool, + pub skip_seaql_migrations: bool, + #[serde(flatten)] + pub table: Option, +} + +impl Default for DiscoveryFilterConfig { + fn default() -> Self { + Self { + include_hidden: false, + skip_seaql_migrations: true, + table: None, + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] +pub enum TableConfig { + Specific { only: Vec }, + Exclude { exclude: Vec }, +} + +impl DiscoveryFilterConfig { + pub fn get_filter(&self) -> Box bool + Send> { + let include_hidden = self.include_hidden; + if let Some(table) = &self.table { + match table { + TableConfig::Specific { only } => { + let specific = only.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && specific.contains(table) + }) + } + TableConfig::Exclude { exclude } => { + let exclude = exclude.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && !exclude.contains(table) + }) + } + } + } else if self.skip_seaql_migrations { + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) + && !table.starts_with("seaql_migrations") + }) + } else { + Box::new(move |table: &String| (include_hidden || !table.starts_with('_'))) + } + } +} + +pub struct DiscoveredSchema { + pub tables: Vec, +} + #[derive(Debug)] pub struct DiscoveryModule; #[async_trait::async_trait] impl Module for DiscoveryModule { - fn init(&self, ctx: &mut ModulesContext) -> Result<()> { + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { ctx.get_config_auto::("modules.discovery")?; Ok(()) } - async fn validate(&self, ctx: &mut ModulesContext) -> Result { + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { let map = ctx.get_anymap(); if let (Some(config), Some(_)) = (map.get::(), map.get::()) { @@ -36,4 +106,17 @@ impl Module for DiscoveryModule { Ok(false) } } + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + if let (Some(config), Some(url)) = ( + ctx.get_anymap().get::(), + ctx.get_anymap().get::(), + ) { + let url = url.clone(); + + let (stmts, db_type) = db::get_tables(url.0, config).await?; + tracing::info!(?stmts, ?db_type); + // db::generate(ctx).await?; + } + Ok(()) + } } diff --git a/src/generator/modules/mod.rs b/src/generator/modules/mod.rs index 7b84b6a..2417274 100644 --- a/src/generator/modules/mod.rs +++ b/src/generator/modules/mod.rs @@ -58,7 +58,7 @@ impl ModulesContext { Err(eyre!("Config not found")) } } - pub fn get_config<'a, V: Deserialize<'a>>(&self, path: &str) -> Result { + pub fn get_config<'a, V: Deserialize<'a> + Debug>(&self, path: &str) -> Result { let item = self.get_config_raw(path)?; let value = item .clone() @@ -66,9 +66,10 @@ impl ModulesContext { .map_err(|_| eyre!("Config not found"))?; let deserializer = value.into_deserializer(); let config = V::deserialize(deserializer)?; + tracing::debug!(?config, "{}", path); Ok(config) } - pub fn get_config_auto<'a, V: Deserialize<'a> + Clone + Send + 'static>( + pub fn get_config_auto<'a, V: Deserialize<'a> + Clone + Send + Debug + 'static>( &mut self, path: &str, ) -> Result<()> { @@ -89,8 +90,9 @@ impl ModulesContext { } #[async_trait::async_trait] pub trait Module: Debug { - fn init(&self, ctx: &mut ModulesContext) -> Result<()>; - async fn validate(&self, ctx: &mut ModulesContext) -> Result; + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()>; + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result; + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()>; } pub struct ModuleManager { @@ -115,7 +117,7 @@ impl ModuleManager { } pub fn init(&mut self) -> Result<()> { - for module in &self.modules { + for module in &mut self.modules { module.init(&mut self.ctx)?; } Ok(()) @@ -123,10 +125,10 @@ impl ModuleManager { pub async fn validate(&mut self) -> Result<()> { let mut index_wr = 0usize; for index in 0..self.modules.len() { - let module = &self.modules[index]; + let module = &mut self.modules[index]; let enabled = module.validate(&mut self.ctx).await?; tracing::info!(?module, ?enabled); - if !enabled { + if enabled { self.modules.swap(index_wr, index); index_wr += 1; } @@ -134,4 +136,11 @@ impl ModuleManager { self.modules.truncate(index_wr); Ok(()) } + pub async fn execute(&mut self) -> Result<()> { + for module in &mut self.modules { + tracing::debug!(?module, "executing"); + module.execute(&mut self.ctx).await?; + } + Ok(()) + } } diff --git a/src/generator/modules/sea_orm/mod.rs b/src/generator/modules/sea_orm/mod.rs index c2d493d..aaad73a 100644 --- a/src/generator/modules/sea_orm/mod.rs +++ b/src/generator/modules/sea_orm/mod.rs @@ -22,11 +22,11 @@ pub struct SeaOrmModule; #[async_trait::async_trait] impl Module for SeaOrmModule { - fn init(&self, ctx: &mut ModulesContext) -> Result<()> { + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { ctx.get_config_auto::("modules.sea_orm")?; Ok(()) } - async fn validate(&self, ctx: &mut ModulesContext) -> Result { + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { let map = ctx.get_anymap(); if let (Some(config_discovery_config), Some(_), Some(config_sea_orm)) = ( @@ -43,4 +43,7 @@ impl Module for SeaOrmModule { Ok(false) } } + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + Ok(()) + } } diff --git a/src/generator/modules/templates/mod.rs b/src/generator/modules/templates/mod.rs index c4d135c..7cfb1b6 100644 --- a/src/generator/modules/templates/mod.rs +++ b/src/generator/modules/templates/mod.rs @@ -17,13 +17,12 @@ pub struct TemplateModule; #[async_trait::async_trait] impl Module for TemplateModule { - fn init(&self, ctx: &mut ModulesContext) -> Result<()> { + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { let registry: Handlebars<'static> = Handlebars::new(); ctx.get_anymap_mut().insert(registry); - // ctx.get_config_auto::("modules.discovery")?; Ok(()) } - async fn validate(&self, ctx: &mut ModulesContext) -> Result { + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { // let map = ctx.get_anymap(); // // if let (Some(config), Some(_)) = (map.get::(), map.get::()) { @@ -32,6 +31,9 @@ impl Module for TemplateModule { // // One or both keys are missing // Ok(false) // } - Ok(false) + Ok(true) + } + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + Ok(()) } }