diff --git a/Cargo.toml b/Cargo.toml index ee4c70a..4638b6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" clap = { version = "4.5.32", features = ["derive", "env"] } color-eyre = "0.6.3" figment = { version = "0.10.19", features = ["yaml"] } +heck = "0.5.0" inquire = "0.7.5" sea-orm-codegen = "1.1.8" sea-schema = { version = "0.16.1", features = ["sqlx-all"] } diff --git a/src/config.rs b/src/config.rs index 2d92ebe..a475485 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,12 @@ -use sea_orm_codegen::DateTimeCrate as CodegenDateTimeCrate; +use std::path::PathBuf; + +use color_eyre::Report; +use sea_orm_codegen::{ + DateTimeCrate as CodegenDateTimeCrate, EntityWriterContext, WithPrelude, WithSerde, +}; use serde::{Deserialize, Deserializer, Serialize}; use serde_yaml::Value; -use tracing::info; +use tracing::instrument; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -98,6 +103,12 @@ impl Serialize for Prelude { pub struct Config { pub db: DbConfig, pub sea_orm: SeaOrmConfig, + pub output: OutputConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputConfig { + pub path: PathBuf, } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -128,11 +139,13 @@ pub struct SeaOrmEntityConfig { pub extra_derives: SeaOrmExtraDerivesConfig, pub extra_attributes: SeaOrmExtraAttributesConfig, pub date_time_crate: DateTimeCrate, + pub with_copy_enums: bool, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct SeaOrmTableConfig { pub include_hidden: bool, + pub skip_seaql_migrations: bool, #[serde(flatten)] pub table_config: Option, } @@ -149,7 +162,8 @@ pub struct SeaOrmExtraAttributesConfig { pub eenum: Vec, } #[derive(Deserialize, Serialize, Debug, Clone)] -enum DateTimeCrate { +#[serde(rename_all = "lowercase")] +pub enum DateTimeCrate { Time, Chrono, } @@ -181,8 +195,13 @@ impl SeaOrmTableConfig { }) } } + } else if self.skip_seaql_migrations { + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) + && !table.starts_with("seaql_migrations") + }) } else { - Box::new(move |table: &String| include_hidden || !table.starts_with('_')) + Box::new(move |table: &String| (include_hidden || !table.starts_with('_'))) } } } @@ -206,6 +225,7 @@ impl Default for Config { format: EntityFormat::Compact, tables: SeaOrmTableConfig { include_hidden: false, + skip_seaql_migrations: true, table_config: None, }, extra_derives: SeaOrmExtraDerivesConfig { @@ -217,8 +237,60 @@ impl Default for Config { eenum: Vec::new(), }, date_time_crate: DateTimeCrate::Chrono, + with_copy_enums: false, }, }, + output: OutputConfig { + path: PathBuf::from("./entities"), + }, } } } + +impl EntityFormat { + pub fn is_expanded(&self) -> bool { + matches!(self, EntityFormat::Expanded) + } +} +impl From for WithPrelude { + fn from(val: Prelude) -> Self { + match val { + Prelude::Enabled => WithPrelude::All, + + Prelude::Disabled => WithPrelude::None, + Prelude::AllowUnusedImports => WithPrelude::AllAllowUnusedImports, + } + } +} +impl From for WithSerde { + fn from(val: SerdeEnable) -> Self { + match val { + SerdeEnable::Both => WithSerde::Both, + SerdeEnable::Serialize => WithSerde::Serialize, + SerdeEnable::Deserialize => WithSerde::Deserialize, + SerdeEnable::None => WithSerde::None, + } + } +} + +impl From for EntityWriterContext { + fn from(val: Config) -> Self { + EntityWriterContext::new( + val.sea_orm.entity.format.is_expanded(), + val.sea_orm.prelude.into(), + val.sea_orm.serde.enable.into(), + val.sea_orm.entity.with_copy_enums, + val.sea_orm.entity.date_time_crate.into(), + val.db.database_schema, + false, + val.sea_orm.serde.skip_deserializing_primary_key, + val.sea_orm.serde.skip_hidden_column, + val.sea_orm.entity.extra_derives.model, + val.sea_orm.entity.extra_attributes.model, + val.sea_orm.entity.extra_derives.eenum, + val.sea_orm.entity.extra_attributes.eenum, + false, + false, + ) + } +} diff --git a/src/generate.rs b/src/generate.rs index 8b05e58..ef2ad0c 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -1,13 +1,15 @@ use core::time; +use crate::Config; use color_eyre::{ eyre::{eyre, ContextCompat, Report}, Result, }; -use sea_schema::sea_query::TableCreateStatement; +use sea_orm_codegen::OutputFile; +use sea_schema::sea_query::{self, TableCreateStatement}; +use tokio::{fs, task::JoinSet}; use url::Url; -use crate::Config; pub async fn get_tables( database_url: String, config: &Config, @@ -18,6 +20,7 @@ pub async fn get_tables( let is_sqlite = url.scheme() == "sqlite"; let filter_tables = config.sea_orm.entity.tables.get_filter(); + // let filter_tables = |table: &String| -> bool { // config.sea_orm.entity.table.only.is_empty() || config.sea_orm.table.only.contains(table) // }; @@ -135,7 +138,6 @@ pub async fn get_tables( Ok((schema_name, table_stmts)) } - async fn sqlx_connect( max_connections: u32, acquire_timeout: u64, @@ -164,3 +166,53 @@ where } pool_options.connect(url).await.map_err(Into::into) } + +pub async fn generate_models( + tables: Vec, + config: Config, +) -> Result> { + tracing::info!(?tables); + let output_path = config.output.path; + let files = tables + .into_iter() + .map(|table| { + let output_path = output_path.clone(); + async move { + let table_name = match table.get_table_name() { + Some(table_ref) => match table_ref { + sea_query::TableRef::Table(t) + | sea_query::TableRef::SchemaTable(_, t) + | sea_query::TableRef::DatabaseSchemaTable(_, _, t) + | sea_query::TableRef::TableAlias(t, _) + | sea_query::TableRef::SchemaTableAlias(_, t, _) + | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => { + t.to_string() + } + _ => unimplemented!(), + }, + None => return Err(eyre!("Table name not found")), + }; + let file_path = output_path.join(&table_name); + let exists = file_path.exists(); + let content = match exists { + true => { + // let file_content = fs::read_to_string(path) + } + false => {} + }; + + Ok(OutputFile { + name: format!("{}.rs", table_name), + content: String::new(), + }) + } + }) + .collect::>>(); + let files = files + .join_all() + .await + .into_iter() + .collect::>>()?; + Ok(files) +} +// diff --git a/src/main.rs b/src/main.rs index e897db9..d68431e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,93 +1,21 @@ mod config; mod generate; -use std::{fs, path::PathBuf, str::FromStr}; +use std::{path::PathBuf, str::FromStr}; use clap::Parser; -use color_eyre::{Report, Result}; +use color_eyre::{eyre::eyre, Report, Result}; use config::Config; use figment::{ providers::{Format, Serialized, Yaml}, Figment, }; use sea_orm_codegen::{ - DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, WithPrelude, - WithSerde, + DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile, + WithPrelude, WithSerde, }; use serde::{Deserialize, Serialize}; +use tokio::{fs, io::AsyncWriteExt, process::Command}; -// #[derive(Deserialize, Serialize, Debug, Clone)] -// struct Config { -// sea_orm: SeaOrmConfig, -// } - -// #[derive(Deserialize, Serialize, Debug, Clone)] -// struct SeaOrmConfig { -// expanded_format: bool, -// table: SeaOrmTableConfig, -// with_serde: String, -// with_prelude: String, -// with_copy_enums: bool, -// serde_skip_deserializing_primary_key: bool, -// serde_skip_hidden_column: bool, -// max_connections: u32, -// acquire_timeout: u64, -// database_schema: Option, -// date_format: DateTimeCrate, -// } -// #[derive(Deserialize, Serialize, Debug, Clone)] -// struct SeaOrmTableConfig { -// include_hidden: bool, -// only: Vec, -// exclude: Vec, -// } -// -// impl Default for Config { -// fn default() -> Self { -// Self { -// sea_orm: SeaOrmConfig { -// table: SeaOrmTableConfig { -// include_hidden: false, -// only: Vec::new(), -// exclude: Vec::new(), -// }, -// database_schema: None, -// max_connections: 10, -// acquire_timeout: 30, -// expanded_format: false, -// with_copy_enums: false, -// with_serde: String::from("none"), -// with_prelude: String::from("none"), -// serde_skip_hidden_column: false, -// serde_skip_deserializing_primary_key: false, -// date_format: DateTimeCrate::Time, -// }, -// } -// } -// } - -// impl TryInto for Config { -// type Error = Report; -// fn try_into(self) -> Result { -// Ok(EntityWriterContext::new( -// self.sea_orm.expanded_format, -// WithPrelude::from_str(&self.sea_orm.with_prelude)?, -// WithSerde::from_str(&self.sea_orm.with_serde)?, -// self.sea_orm.with_copy_enums, -// self.sea_orm.date_format.into(), -// self.sea_orm.database_schema, -// false, -// self.sea_orm.serde_skip_deserializing_primary_key, -// self.sea_orm.serde_skip_hidden_column, -// Vec::new(), -// Vec::new(), -// Vec::new(), -// Vec::new(), -// false, -// false, -// )) -// } -// } -// #[derive(Parser, Debug)] struct Args { #[clap(short, long, default_value = "generator.yml")] @@ -102,13 +30,41 @@ async fn main() -> Result<()> { let args = Args::parse(); let config: Config = Figment::new() - // .merge(Serialized::defaults(Config::default())) + .merge(Serialized::defaults(Config::default())) .merge(Yaml::file(&args.config)) .extract()?; tracing::info!(?config); tracing::info!(?args); - // let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; - // let writer_context = config.clone().try_into()?; - // let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context); + let output_dir = &config.output.path; + let output_internal_entities = output_dir.join("_entities"); + let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; + let writer_context = config.clone().into(); + let output = EntityTransformer::transform(table_stmts.clone())?.generate(&writer_context); + let mut files = output + .files + .into_iter() + .map(|OutputFile { name, content }| (output_internal_entities.join(name), content)) + .collect::>(); + let generate_files = generate::generate_models(table_stmts, config.clone()) + .await? + .into_iter() + .map(|OutputFile { name, content }| (output_dir.join(name), content)) + .collect::>(); + files.extend(generate_files); + tracing::info!("Generated {} files", files.len()); + fs::create_dir_all(&output_internal_entities).await?; + for (file_path, content) in files.iter() { + tracing::info!(?file_path, "Writing file"); + let mut file = fs::File::create(&file_path).await?; + file.write_all(content.as_bytes()).await?; + } + for (file_path, ..) in files.iter() { + tracing::info!(?file_path, "Running rustfmt"); + let exit_status = Command::new("rustfmt").arg(file_path).status().await?; // Get the status code + if !exit_status.success() { + // Propagate the error if any + return Err(eyre!("Failed to run rustfmt")); + } + } Ok(()) }