diff --git a/Cargo.toml b/Cargo.toml index 9731bef..6e9f0dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,8 @@ edition = "2021" [dependencies] clap = { version = "4.5.32", features = ["derive", "env"] } color-eyre = "0.6.3" -figment = { version = "0.10.19", features = ["toml"] } +confique = { version = "0.3.0", features = ["yaml", "toml"] } +inquire = "0.7.5" sea-orm-codegen = "1.1.8" sea-schema = { version = "0.16.1", features = ["sqlx-all"] } serde = { version = "1.0.219", features = ["derive"] } diff --git a/src/generate.rs b/src/generate.rs index 57b9abc..1d89012 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -10,7 +10,7 @@ use url::Url; use crate::Config; pub async fn get_tables( database_url: String, - config: Config, + config: &Config, ) -> Result<(Option, Vec)> { let url = Url::parse(&database_url)?; @@ -51,7 +51,7 @@ pub async fn get_tables( use sea_schema::mysql::discovery::SchemaDiscovery; use sqlx::MySql; - tracing::info!("Connecting to MySQL ..."); + tracing::info!("Connecting to MySQL"); let connection = sqlx_connect::( config.sea_orm.max_connections, config.sea_orm.acquire_timeout, @@ -60,7 +60,7 @@ pub async fn get_tables( ) .await?; - tracing::info!("Discovering schema ..."); + tracing::info!("Discovering schema"); let schema_discovery = SchemaDiscovery::new(connection, database_name); let schema = schema_discovery.discover().await?; let table_stmts = schema @@ -77,7 +77,7 @@ pub async fn get_tables( use sea_schema::sqlite::discovery::SchemaDiscovery; use sqlx::Sqlite; - tracing::info!("Connecting to SQLite ..."); + tracing::info!("Connecting to SQLite"); let connection = sqlx_connect::( config.sea_orm.max_connections, config.sea_orm.acquire_timeout, @@ -86,7 +86,7 @@ pub async fn get_tables( ) .await?; - tracing::info!("Discovering schema ..."); + tracing::info!("Discovering schema"); let schema_discovery = SchemaDiscovery::new(connection); let schema = schema_discovery .discover() @@ -106,7 +106,7 @@ pub async fn get_tables( use sea_schema::postgres::discovery::SchemaDiscovery; use sqlx::Postgres; - tracing::info!("Connecting to Postgres ..."); + tracing::info!("Connecting to Postgres"); let schema = config .sea_orm .database_schema @@ -119,7 +119,7 @@ pub async fn get_tables( Some(schema), ) .await?; - tracing::info!("Discovering schema ..."); + tracing::info!("Discovering schema"); let schema_discovery = SchemaDiscovery::new(connection, schema); let schema = schema_discovery.discover().await?; let table_stmts = schema @@ -130,7 +130,7 @@ pub async fn get_tables( .filter(|schema| filter_skip_tables(&schema.info.name)) .map(|schema| schema.write()) .collect(); - (config.sea_orm.database_schema, table_stmts) + (config.sea_orm.database_schema.clone(), table_stmts) } _ => unimplemented!("{} is not supported", url.scheme()), }; diff --git a/src/main.rs b/src/main.rs index f257faf..26f0977 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,32 +1,53 @@ mod generate; -use std::{fs, path::PathBuf}; +use std::{fs, path::PathBuf, str::FromStr}; use clap::Parser; -use color_eyre::Result; +use color_eyre::{Report, Result}; use figment::{ providers::{Format, Serialized, Toml}, Figment, }; +use sea_orm_codegen::{ + DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, WithPrelude, + WithSerde, +}; use serde::{Deserialize, Serialize}; -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] struct Config { sea_orm: SeaOrmConfig, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] +enum DateTimeCrate { + Time, + Chrono, +} + +impl From for CodegenDateTimeCrate { + fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate { + match date_time_crate { + DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono, + DateTimeCrate::Time => CodegenDateTimeCrate::Time, + } + } +} + +#[derive(Deserialize, Serialize, Debug, Clone)] struct SeaOrmConfig { expanded_format: bool, table: SeaOrmTableConfig, - include_hidden_tables: bool, - with_serde: bool, + with_serde: String, + with_prelude: String, + with_copy_enums: bool, serde_skip_deserializing_primary_key: bool, serde_skip_hidden_column: bool, max_connections: u32, acquire_timeout: u64, database_schema: Option, + date_format: DateTimeCrate, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] struct SeaOrmTableConfig { include_hidden: bool, only: Vec, @@ -46,15 +67,40 @@ impl Default for Config { max_connections: 10, acquire_timeout: 30, expanded_format: false, - include_hidden_tables: false, - with_serde: false, + with_copy_enums: false, + with_serde: String::from("none"), + with_prelude: String::from("none"), serde_skip_hidden_column: false, serde_skip_deserializing_primary_key: false, + date_format: DateTimeCrate::Time, }, } } } +impl TryInto for Config { + type Error = Report; + fn try_into(self) -> Result { + Ok(EntityWriterContext::new( + self.sea_orm.expanded_format, + WithPrelude::from_str(&self.sea_orm.with_prelude)?, + WithSerde::from_str(&self.sea_orm.with_serde)?, + self.sea_orm.with_copy_enums, + self.sea_orm.date_format.into(), + self.sea_orm.database_schema, + false, + self.sea_orm.serde_skip_deserializing_primary_key, + self.sea_orm.serde_skip_hidden_column, + Vec::new(), + Vec::new(), + Vec::new(), + Vec::new(), + false, + false, + )) + } +} + #[derive(Parser, Debug)] struct Args { #[clap(short, long, default_value = "generator.toml")] @@ -73,6 +119,8 @@ async fn main() -> Result<()> { .extract()?; tracing::info!(?config); tracing::info!(?args); - generate::get_tables(args.database_url, config).await?; + let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; + let writer_context = config.clone().try_into()?; + let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context); Ok(()) }