From dc9695dac0fd6f8d54c83d339e6323daf39de648 Mon Sep 17 00:00:00 2001 From: Nikkuss Date: Tue, 1 Apr 2025 21:33:03 +0400 Subject: [PATCH] rework --- Cargo.toml | 3 +- src/config.rs | 224 ++++++++++++++++++++++++++++++++++++++++++++++++ src/generate.rs | 61 +++++++------ src/main.rs | 173 +++++++++++++++++-------------------- 4 files changed, 335 insertions(+), 126 deletions(-) create mode 100644 src/config.rs diff --git a/Cargo.toml b/Cargo.toml index 6e9f0dc..ee4c70a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,11 +6,12 @@ edition = "2021" [dependencies] clap = { version = "4.5.32", features = ["derive", "env"] } color-eyre = "0.6.3" -confique = { version = "0.3.0", features = ["yaml", "toml"] } +figment = { version = "0.10.19", features = ["yaml"] } inquire = "0.7.5" sea-orm-codegen = "1.1.8" sea-schema = { version = "0.16.1", features = ["sqlx-all"] } serde = { version = "1.0.219", features = ["derive"] } +serde_yaml = "0.9.34" sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] } syn = { version = "2.0.100", features = ["extra-traits", "full"] } tokio = { version = "1.44.1", features = ["full"] } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..2d92ebe --- /dev/null +++ b/src/config.rs @@ -0,0 +1,224 @@ +use sea_orm_codegen::DateTimeCrate as CodegenDateTimeCrate; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_yaml::Value; +use tracing::info; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum EntityFormat { + Expanded, + Compact, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] +pub enum TableConfig { + Specific { specific: Vec }, + Exclude { exclude: Vec }, +} + +#[derive(Debug, Clone)] +pub enum SerdeEnable { + Both, + Serialize, + Deserialize, + None, +} + +#[derive(Debug, Clone)] +pub enum Prelude { + Enabled, + Disabled, + AllowUnusedImports, +} + +impl<'de> Deserialize<'de> for SerdeEnable { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) if s == "serialize" => Ok(SerdeEnable::Serialize), + Value::String(s) if s == "deserialize" => Ok(SerdeEnable::Deserialize), + Value::Bool(true) => Ok(SerdeEnable::Both), + Value::Bool(false) => Ok(SerdeEnable::None), + _ => Err(serde::de::Error::custom( + "expected 'serialize', 'deserialize', 'true' or 'false'", + )), + } + } +} +impl Serialize for SerdeEnable { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + SerdeEnable::Both => serializer.serialize_bool(true), + SerdeEnable::Serialize => serializer.serialize_str("serialize"), + SerdeEnable::Deserialize => serializer.serialize_str("deserialize"), + SerdeEnable::None => serializer.serialize_bool(false), + } + } +} +impl<'de> Deserialize<'de> for Prelude { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::Bool(true) => Ok(Prelude::Enabled), + Value::Bool(false) => Ok(Prelude::Disabled), + Value::String(s) if s == "allow_unused_imports" => Ok(Prelude::AllowUnusedImports), + _ => Err(serde::de::Error::custom( + "expected 'true', 'false', or 'allow_unused_imports'", + )), + } + } +} +impl Serialize for Prelude { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Prelude::Enabled => serializer.serialize_bool(true), + Prelude::Disabled => serializer.serialize_bool(false), + Prelude::AllowUnusedImports => serializer.serialize_str("allow_unused_imports"), + } + } +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct Config { + pub db: DbConfig, + pub sea_orm: SeaOrmConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct DbConfig { + pub database_schema: Option, + pub max_connections: u32, + pub acquire_timeout: u64, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmConfig { + pub prelude: Prelude, + pub serde: SeaOrmSerdeConfig, + pub entity: SeaOrmEntityConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmSerdeConfig { + pub enable: SerdeEnable, + pub skip_deserializing_primary_key: bool, + pub skip_hidden_column: bool, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmEntityConfig { + pub format: EntityFormat, + pub tables: SeaOrmTableConfig, + pub extra_derives: SeaOrmExtraDerivesConfig, + pub extra_attributes: SeaOrmExtraAttributesConfig, + pub date_time_crate: DateTimeCrate, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmTableConfig { + pub include_hidden: bool, + #[serde(flatten)] + pub table_config: Option, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmExtraDerivesConfig { + pub model: Vec, + #[serde(rename = "enum")] + pub eenum: Vec, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmExtraAttributesConfig { + pub model: Vec, + #[serde(rename = "enum")] + pub eenum: Vec, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +enum DateTimeCrate { + Time, + Chrono, +} + +impl From for CodegenDateTimeCrate { + fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate { + match date_time_crate { + DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono, + DateTimeCrate::Time => CodegenDateTimeCrate::Time, + } + } +} + +impl SeaOrmTableConfig { + pub fn get_filter(&self) -> Box bool> { + let include_hidden = self.include_hidden; + if let Some(table) = &self.table_config { + match table { + TableConfig::Specific { specific } => { + let specific = specific.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && specific.contains(table) + }) + } + TableConfig::Exclude { exclude } => { + let exclude = exclude.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && !exclude.contains(table) + }) + } + } + } else { + Box::new(move |table: &String| include_hidden || !table.starts_with('_')) + } + } +} + +impl Default for Config { + fn default() -> Self { + Self { + db: DbConfig { + database_schema: None, + max_connections: 10, + acquire_timeout: 5, + }, + sea_orm: SeaOrmConfig { + prelude: Prelude::Enabled, + serde: SeaOrmSerdeConfig { + enable: SerdeEnable::None, + skip_deserializing_primary_key: false, + skip_hidden_column: false, + }, + entity: SeaOrmEntityConfig { + format: EntityFormat::Compact, + tables: SeaOrmTableConfig { + include_hidden: false, + table_config: None, + }, + extra_derives: SeaOrmExtraDerivesConfig { + model: Vec::new(), + eenum: Vec::new(), + }, + extra_attributes: SeaOrmExtraAttributesConfig { + model: Vec::new(), + eenum: Vec::new(), + }, + date_time_crate: DateTimeCrate::Chrono, + }, + }, + } + } +} diff --git a/src/generate.rs b/src/generate.rs index 1d89012..8b05e58 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -17,20 +17,21 @@ pub async fn get_tables( tracing::trace!(?url); let is_sqlite = url.scheme() == "sqlite"; - let filter_tables = |table: &String| -> bool { - config.sea_orm.table.only.is_empty() || config.sea_orm.table.only.contains(table) - }; - - let filter_hidden_tables = |table: &str| -> bool { - if false { - true - } else { - !table.starts_with('_') - } - }; - - let filter_skip_tables = - |table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) }; + let filter_tables = config.sea_orm.entity.tables.get_filter(); + // let filter_tables = |table: &String| -> bool { + // config.sea_orm.entity.table.only.is_empty() || config.sea_orm.table.only.contains(table) + // }; + // + // let filter_hidden_tables = |table: &str| -> bool { + // if false { + // true + // } else { + // !table.starts_with('_') + // } + // }; + // + // let filter_skip_tables = + // |table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) }; let database_name: &str = (if !is_sqlite { let database_name = url .path_segments() @@ -53,8 +54,8 @@ pub async fn get_tables( tracing::info!("Connecting to MySQL"); let connection = sqlx_connect::( - config.sea_orm.max_connections, - config.sea_orm.acquire_timeout, + config.db.max_connections, + config.db.acquire_timeout, url.as_str(), None, ) @@ -67,8 +68,8 @@ pub async fn get_tables( .tables .into_iter() .filter(|schema| filter_tables(&schema.info.name)) - .filter(|schema| filter_hidden_tables(&schema.info.name)) - .filter(|schema| filter_skip_tables(&schema.info.name)) + // .filter(|schema| filter_hidden_tables(&schema.info.name)) + // .filter(|schema| filter_skip_tables(&schema.info.name)) .map(|schema| schema.write()) .collect(); (None, table_stmts) @@ -79,8 +80,8 @@ pub async fn get_tables( tracing::info!("Connecting to SQLite"); let connection = sqlx_connect::( - config.sea_orm.max_connections, - config.sea_orm.acquire_timeout, + config.db.max_connections, + config.db.acquire_timeout, url.as_str(), None, ) @@ -96,8 +97,8 @@ pub async fn get_tables( .tables .into_iter() .filter(|schema| filter_tables(&schema.name)) - .filter(|schema| filter_hidden_tables(&schema.name)) - .filter(|schema| filter_skip_tables(&schema.name)) + // .filter(|schema| filter_hidden_tables(&schema.name)) + // .filter(|schema| filter_skip_tables(&schema.name)) .map(|schema| schema.write()) .collect(); (None, table_stmts) @@ -107,14 +108,10 @@ pub async fn get_tables( use sqlx::Postgres; tracing::info!("Connecting to Postgres"); - let schema = config - .sea_orm - .database_schema - .as_deref() - .unwrap_or("public"); + let schema = &config.db.database_schema.as_deref().unwrap_or("public"); let connection = sqlx_connect::( - config.sea_orm.max_connections, - config.sea_orm.acquire_timeout, + config.db.max_connections, + config.db.acquire_timeout, url.as_str(), Some(schema), ) @@ -126,11 +123,11 @@ pub async fn get_tables( .tables .into_iter() .filter(|schema| filter_tables(&schema.info.name)) - .filter(|schema| filter_hidden_tables(&schema.info.name)) - .filter(|schema| filter_skip_tables(&schema.info.name)) + // .filter(|schema| filter_hidden_tables(&schema.info.name)) + // .filter(|schema| filter_skip_tables(&schema.info.name)) .map(|schema| schema.write()) .collect(); - (config.sea_orm.database_schema.clone(), table_stmts) + (config.db.database_schema.clone(), table_stmts) } _ => unimplemented!("{} is not supported", url.scheme()), }; diff --git a/src/main.rs b/src/main.rs index adfd1a4..e897db9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,12 @@ +mod config; mod generate; use std::{fs, path::PathBuf, str::FromStr}; use clap::Parser; use color_eyre::{Report, Result}; +use config::Config; use figment::{ - providers::{Format, Serialized, Toml}, + providers::{Format, Serialized, Yaml}, Figment, }; use sea_orm_codegen::{ @@ -13,97 +15,82 @@ use sea_orm_codegen::{ }; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Serialize, Debug, Clone)] -struct Config { - sea_orm: SeaOrmConfig, -} +// #[derive(Deserialize, Serialize, Debug, Clone)] +// struct Config { +// sea_orm: SeaOrmConfig, +// } -#[derive(Deserialize, Serialize, Debug, Clone)] -enum DateTimeCrate { - Time, - Chrono, -} - -impl From for CodegenDateTimeCrate { - fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate { - match date_time_crate { - DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono, - DateTimeCrate::Time => CodegenDateTimeCrate::Time, - } - } -} - -#[derive(Deserialize, Serialize, Debug, Clone)] -struct SeaOrmConfig { - expanded_format: bool, - table: SeaOrmTableConfig, - with_serde: String, - with_prelude: String, - with_copy_enums: bool, - serde_skip_deserializing_primary_key: bool, - serde_skip_hidden_column: bool, - max_connections: u32, - acquire_timeout: u64, - database_schema: Option, - date_format: DateTimeCrate, -} -#[derive(Deserialize, Serialize, Debug, Clone)] -struct SeaOrmTableConfig { - include_hidden: bool, - only: Vec, - exclude: Vec, -} - -impl Default for Config { - fn default() -> Self { - Self { - sea_orm: SeaOrmConfig { - table: SeaOrmTableConfig { - include_hidden: false, - only: Vec::new(), - exclude: Vec::new(), - }, - database_schema: None, - max_connections: 10, - acquire_timeout: 30, - expanded_format: false, - with_copy_enums: false, - with_serde: String::from("none"), - with_prelude: String::from("none"), - serde_skip_hidden_column: false, - serde_skip_deserializing_primary_key: false, - date_format: DateTimeCrate::Time, - }, - } - } -} - -impl TryInto for Config { - type Error = Report; - fn try_into(self) -> Result { - Ok(EntityWriterContext::new( - self.sea_orm.expanded_format, - WithPrelude::from_str(&self.sea_orm.with_prelude)?, - WithSerde::from_str(&self.sea_orm.with_serde)?, - self.sea_orm.with_copy_enums, - self.sea_orm.date_format.into(), - self.sea_orm.database_schema, - false, - self.sea_orm.serde_skip_deserializing_primary_key, - self.sea_orm.serde_skip_hidden_column, - Vec::new(), - Vec::new(), - Vec::new(), - Vec::new(), - false, - false, - )) - } -} +// #[derive(Deserialize, Serialize, Debug, Clone)] +// struct SeaOrmConfig { +// expanded_format: bool, +// table: SeaOrmTableConfig, +// with_serde: String, +// with_prelude: String, +// with_copy_enums: bool, +// serde_skip_deserializing_primary_key: bool, +// serde_skip_hidden_column: bool, +// max_connections: u32, +// acquire_timeout: u64, +// database_schema: Option, +// date_format: DateTimeCrate, +// } +// #[derive(Deserialize, Serialize, Debug, Clone)] +// struct SeaOrmTableConfig { +// include_hidden: bool, +// only: Vec, +// exclude: Vec, +// } +// +// impl Default for Config { +// fn default() -> Self { +// Self { +// sea_orm: SeaOrmConfig { +// table: SeaOrmTableConfig { +// include_hidden: false, +// only: Vec::new(), +// exclude: Vec::new(), +// }, +// database_schema: None, +// max_connections: 10, +// acquire_timeout: 30, +// expanded_format: false, +// with_copy_enums: false, +// with_serde: String::from("none"), +// with_prelude: String::from("none"), +// serde_skip_hidden_column: false, +// serde_skip_deserializing_primary_key: false, +// date_format: DateTimeCrate::Time, +// }, +// } +// } +// } +// impl TryInto for Config { +// type Error = Report; +// fn try_into(self) -> Result { +// Ok(EntityWriterContext::new( +// self.sea_orm.expanded_format, +// WithPrelude::from_str(&self.sea_orm.with_prelude)?, +// WithSerde::from_str(&self.sea_orm.with_serde)?, +// self.sea_orm.with_copy_enums, +// self.sea_orm.date_format.into(), +// self.sea_orm.database_schema, +// false, +// self.sea_orm.serde_skip_deserializing_primary_key, +// self.sea_orm.serde_skip_hidden_column, +// Vec::new(), +// Vec::new(), +// Vec::new(), +// Vec::new(), +// false, +// false, +// )) +// } +// } +// #[derive(Parser, Debug)] struct Args { - #[clap(short, long, default_value = "generator.toml")] + #[clap(short, long, default_value = "generator.yml")] config: String, #[clap(short, long, env = "DATABASE_URL")] database_url: String, @@ -115,13 +102,13 @@ async fn main() -> Result<()> { let args = Args::parse(); let config: Config = Figment::new() - .merge(Serialized::defaults(Config::default())) - .merge(Toml::file(&args.config)) + // .merge(Serialized::defaults(Config::default())) + .merge(Yaml::file(&args.config)) .extract()?; tracing::info!(?config); tracing::info!(?args); - let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; - let writer_context = config.clone().try_into()?; - let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context); + // let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; + // let writer_context = config.clone().try_into()?; + // let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context); Ok(()) }