diff --git a/Cargo.lock b/Cargo.lock index ea2e1af..1c0ec74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,15 +108,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atomic" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" -dependencies = [ - "bytemuck", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -180,12 +171,6 @@ version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" -[[package]] -name = "bytemuck" -version = "1.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" - [[package]] name = "byteorder" version = "1.5.0" @@ -595,19 +580,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "figment" -version = "0.10.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cb01cd46b0cf372153850f4c6c272d9cbea2da513e07538405148f95bd789f3" -dependencies = [ - "atomic", - "serde", - "serde_yaml", - "uncased", - "version_check", -] - [[package]] name = "flume" version = "0.11.1" @@ -1379,6 +1351,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17359afc20d7ab31fdb42bb844c8b3bb1dabd7dcf7e68428492da7f16966fcef" +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -1675,23 +1653,22 @@ dependencies = [ "color-eyre", "comfy-table", "comment-parser", - "figment", "handlebars", "heck 0.5.0", "include_dir", "indicatif", "inquire", "path-clean", + "pathdiff", "quote", "sea-orm-codegen", "sea-schema", "serde", "serde-inline-default", - "serde_json", - "serde_yaml", "sqlx", "syn", "tokio", + "toml", "toml_edit", "tracing", "tracing-subscriber", @@ -1821,19 +1798,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_yaml" -version = "0.9.34+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" -dependencies = [ - "indexmap", - "itoa", - "ryu", - "serde", - "unsafe-libyaml", -] - [[package]] name = "sha1" version = "0.10.6" @@ -2304,6 +2268,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -2410,15 +2386,6 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" -[[package]] -name = "uncased" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b88fcfe09e89d3866a5c11019378088af2d24c3fbd4f0543f96b479ec90697" -dependencies = [ - "version_check", -] - [[package]] name = "unicode-bidi" version = "0.3.18" @@ -2464,12 +2431,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" -[[package]] -name = "unsafe-libyaml" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" - [[package]] name = "url" version = "2.5.4" @@ -2787,9 +2748,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e97b544156e9bebe1a0ffbc03484fc1ffe3100cbce3ffb17eac35f7cdd7ab36" +checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 0137e53..10d6168 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,23 +10,22 @@ clap = { version = "4.5.32", features = ["derive", "env"] } color-eyre = "0.6.3" comfy-table = { version = "7.1.4", default-features = false } comment-parser = "0.1.0" -figment = { version = "0.10.19", features = ["yaml"] } handlebars = "6.3.2" heck = "0.5.0" include_dir = "0.7.4" indicatif = "0.17.11" inquire = "0.7.5" path-clean = "1.0.1" +pathdiff = "0.2.3" quote = "1.0.40" sea-orm-codegen = "1.1.8" sea-schema = { version = "0.16.1", features = ["sqlx-all"] } serde = { version = "1.0.219", features = ["derive"] } serde-inline-default = "0.2.3" -serde_json = "1.0.140" -serde_yaml = "0.9.34" sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite", "runtime-tokio"] } syn = { version = "2.0.100", features = ["extra-traits", "full"] } tokio = { version = "1.44.1", features = ["full"] } +toml = "0.8.20" toml_edit = { version = "0.22.24", features = ["serde"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/generator.toml b/generator.toml index 9ca21c4..1ea48e0 100644 --- a/generator.toml +++ b/generator.toml @@ -1,10 +1,19 @@ # This file is used to configure the SeaORM generator. [modules.discovery] enable = true -[modules.discovery.filter] -include_hidden = false -skip_seaql_migrations = true [modules.sea_orm] enable = true path = "./tests/src/models/_entities" + +[modules.template] +enable = true +[modules.template.tables] + +[modules.model] +enable = true +prelude = true +path = "./tests/src/models" + +# [modules.annotate] +# enable = true diff --git a/process-compose.nix b/process-compose.nix index f2fb5af..4fa5bea 100644 --- a/process-compose.nix +++ b/process-compose.nix @@ -3,7 +3,7 @@ processes = { frontend = { command = '' - RUST_LOG=debug,sqlx=warn ${pkgs.cargo-watch}/bin/cargo-watch -x 'run' + RUST_LOG=debug,sqlx=warn ${pkgs.cargo-watch}/bin/cargo-watch -i tests/src -x 'run' ''; }; }; diff --git a/src/generator/discover.rs b/src/generator/discover.rs deleted file mode 100644 index 0a5f6a1..0000000 --- a/src/generator/discover.rs +++ /dev/null @@ -1,154 +0,0 @@ -use core::time; - -use color_eyre::eyre::{eyre, ContextCompat, Report, Result}; -use sea_schema::sea_query::TableCreateStatement; -use url::Url; - -use crate::config::db::DbConfig; -#[derive(Debug, Clone)] -pub enum DbType { - MySql, - Postgres, - Sqlite, -} - -pub async fn get_tables( - database_url: String, - filter: Box bool>, - database_config: &DbConfig, -) -> Result<(Vec, DbType)> { - let url = Url::parse(&database_url)?; - - tracing::trace!(?url); - - let is_sqlite = url.scheme() == "sqlite"; - - let database_name: &str = (if !is_sqlite { - let database_name = url - .path_segments() - .context("No database name as part of path")? - .next() - .context("No database name as part of path")?; - - if database_name.is_empty() { - return Err(eyre!("Database path name is empty")); - } - Ok::<&str, Report>(database_name) - } else { - Ok(Default::default()) - })?; - - let (table_stmts, db_type) = match url.scheme() { - "mysql" => { - use sea_schema::mysql::discovery::SchemaDiscovery; - use sqlx::MySql; - - tracing::info!("Connecting to MySQL"); - let connection = sqlx_connect::( - database_config.max_connections, - database_config.acquire_timeout, - url.as_str(), - None, - ) - .await?; - - tracing::info!("Discovering schema"); - let schema_discovery = SchemaDiscovery::new(connection, database_name); - let schema = schema_discovery.discover().await?; - let table_stmts = schema - .tables - .into_iter() - .filter(|schema| filter(&schema.info.name)) - .map(|schema| schema.write()) - .collect(); - (table_stmts, DbType::MySql) - } - "sqlite" => { - use sea_schema::sqlite::discovery::SchemaDiscovery; - use sqlx::Sqlite; - - tracing::info!("Connecting to SQLite"); - let connection = sqlx_connect::( - database_config.max_connections, - database_config.acquire_timeout, - url.as_str(), - None, - ) - .await?; - - tracing::info!("Discovering schema"); - let schema_discovery = SchemaDiscovery::new(connection); - let schema = schema_discovery - .discover() - .await? - .merge_indexes_into_table(); - let table_stmts = schema - .tables - .into_iter() - .filter(|schema| filter(&schema.name)) - .map(|schema| schema.write()) - .collect(); - (table_stmts, DbType::Sqlite) - } - "postgres" | "potgresql" => { - use sea_schema::postgres::discovery::SchemaDiscovery; - use sqlx::Postgres; - - tracing::info!("Connecting to Postgres"); - let schema = &database_config - .database_schema - .as_deref() - .unwrap_or("public"); - let connection = sqlx_connect::( - database_config.max_connections, - database_config.acquire_timeout, - url.as_str(), - Some(schema), - ) - .await?; - tracing::info!("Discovering schema"); - let schema_discovery = SchemaDiscovery::new(connection, schema); - let schema = schema_discovery.discover().await?; - tracing::info!(?schema); - let table_stmts = schema - .tables - .into_iter() - .filter(|schema| filter(&schema.info.name)) - .map(|schema| schema.write()) - .collect(); - (table_stmts, DbType::Postgres) - } - _ => unimplemented!("{} is not supported", url.scheme()), - }; - tracing::info!("Schema discovered"); - - Ok((table_stmts, db_type)) -} -async fn sqlx_connect( - max_connections: u32, - acquire_timeout: u64, - url: &str, - schema: Option<&str>, -) -> Result> -where - DB: sqlx::Database, - for<'a> &'a mut ::Connection: sqlx::Executor<'a>, -{ - let mut pool_options = sqlx::pool::PoolOptions::::new() - .max_connections(max_connections) - .acquire_timeout(time::Duration::from_secs(acquire_timeout)); - // Set search_path for Postgres, E.g. Some("public") by default - // MySQL & SQLite connection initialize with schema `None` - if let Some(schema) = schema { - let sql = format!("SET search_path = '{schema}'"); - pool_options = pool_options.after_connect(move |conn, _| { - let sql = sql.clone(); - Box::pin(async move { - sqlx::Executor::execute(conn, sql.as_str()) - .await - .map(|_| ()) - }) - }); - } - pool_options.connect(url).await.map_err(Into::into) -} diff --git a/src/generator/file.rs b/src/generator/file.rs index 500c280..a4491d2 100644 --- a/src/generator/file.rs +++ b/src/generator/file.rs @@ -1,6 +1,7 @@ use color_eyre::Result; use path_clean::PathClean; use std::{collections::HashMap, path::PathBuf}; +use tokio::{fs::File, io::AsyncWriteExt}; pub fn pathbuf_to_rust_path(path: PathBuf) -> String { let clean_path = path.clean(); @@ -29,6 +30,7 @@ pub fn pathbuf_to_rust_path(path: PathBuf) -> String { #[derive(Debug, Clone)] pub enum InsertPoint { Start, + Replace(String), End, } #[derive(Debug, Clone)] @@ -47,20 +49,44 @@ impl FileManager { files: HashMap::new(), } } - pub fn insert_file( + pub fn insert( &mut self, - file: PathBuf, - content: String, + file: &PathBuf, + content: &str, insert_point: Option, ) -> Result<()> { - if let Some(file) = self.files.get_mut(&file) { + if let Some(file) = self.files.get_mut(file) { match insert_point { - Some(InsertPoint::Start) => file.content.insert_str(0, &content), - Some(InsertPoint::End) => file.content.push_str(&content), - None => file.content.push_str(&content), + Some(InsertPoint::Start) => file.content.insert_str(0, content), + Some(InsertPoint::End) => file.content.push_str(content), + None => file.content.push_str(content), + Some(InsertPoint::Replace(replace)) => { + let content = file.content.replace(&replace, content); + file.content = content; + } } } else { - self.files.insert(file.clone(), FileContent { content }); + self.files.insert( + file.clone(), + FileContent { + content: content.to_string(), + }, + ); + } + Ok(()) + } + pub fn get(&self, file: &PathBuf) -> Option<&FileContent> { + self.files.get(file) + } + pub async fn write_files(&self) -> Result<()> { + for (file, content) in &self.files { + tracing::debug!(?file, "Writing file"); + let parent = file.parent().unwrap(); + if !parent.exists() { + tokio::fs::create_dir_all(parent).await?; + } + let mut opened_file = File::create(file).await?; + opened_file.write_all(content.content.as_bytes()).await?; } Ok(()) } @@ -68,7 +94,7 @@ impl FileManager { #[cfg(test)] mod test { - use crate::generator::file::pathbuf_to_rust_path; + use crate::generator::file::{pathbuf_to_rust_path, FileManager, InsertPoint}; use std::path::PathBuf; #[test] fn test_pathbuf_to_rust_path() { @@ -90,4 +116,49 @@ mod test { let rust_path = pathbuf_to_rust_path(path); assert_eq!(rust_path, ""); } + #[test] + fn test_fildmanager_insert() { + let mut file_manager = FileManager::new(); + let file_path = PathBuf::from("test.rs"); + file_manager.insert(&file_path, "test", None).unwrap(); + file_manager.insert(&file_path, "test1", None).unwrap(); + assert_eq!(file_manager.get(&file_path).unwrap().content, "testtest1"); + } + #[test] + fn test_fildmanager_insert_start() { + let mut file_manager = FileManager::new(); + let file_path = PathBuf::from("test.rs"); + file_manager.insert(&file_path, "test", None).unwrap(); + file_manager + .insert(&file_path, "teststart", Some(InsertPoint::Start)) + .unwrap(); + assert_eq!( + file_manager.get(&file_path).unwrap().content, + "teststarttest" + ); + } + #[test] + fn test_fildmanager_insert_end() { + let mut file_manager = FileManager::new(); + let file_path = PathBuf::from("test.rs"); + file_manager.insert(&file_path, "test", None).unwrap(); + file_manager + .insert(&file_path, "testend", Some(InsertPoint::End)) + .unwrap(); + assert_eq!(file_manager.get(&file_path).unwrap().content, "testtestend"); + } + #[test] + fn test_fildmanager_insert_replace() { + let mut file_manager = FileManager::new(); + let file_path = PathBuf::from("test.rs"); + file_manager.insert(&file_path, "test", None).unwrap(); + file_manager + .insert( + &file_path, + "testreplace", + Some(InsertPoint::Replace("test".to_string())), + ) + .unwrap(); + assert_eq!(file_manager.get(&file_path).unwrap().content, "testreplace"); + } } diff --git a/src/generator/mod.rs b/src/generator/mod.rs index a83a944..889e5bb 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -1,4 +1,3 @@ -pub mod discover; pub mod file; pub mod modules; use color_eyre::Result; @@ -15,6 +14,11 @@ pub async fn generate(database_url: &str, root_config: DocumentMut) -> Result<() .insert(DatabaseUrl(database_url.to_owned())); module_manager.validate().await?; module_manager.execute().await?; + module_manager + .get_context_mut() + .get_file_manager() + .write_files() + .await?; // let db_filter = config.sea_orm.entity.tables.get_filter(); // let (table_stmts, db_type) = diff --git a/src/generator/modules/annotate/mod.rs b/src/generator/modules/annotate/mod.rs new file mode 100644 index 0000000..1f67661 --- /dev/null +++ b/src/generator/modules/annotate/mod.rs @@ -0,0 +1,43 @@ +use super::{models::ModelsConfig, Module, ModulesContext}; +use color_eyre::Result; +use serde::Deserialize; +#[derive(Debug, Clone, Deserialize)] +#[serde(default)] +pub struct AnnotateConfig { + pub enable: bool, +} + +impl Default for AnnotateConfig { + fn default() -> Self { + Self { enable: false } + } +} + +#[derive(Debug, Default)] +pub struct AnnotateModule { + pub models: bool, +} + +#[async_trait::async_trait] +impl Module for AnnotateModule { + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { + ctx.get_config_auto::("modules.annotate")?; + Ok(()) + } + + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { + let map = ctx.get_anymap(); + + if let (Some(config), Some(models_config)) = + (map.get::(), map.get::()) + { + Ok(config.enable) + } else { + // One or both keys are missing + Ok(false) + } + } + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + Ok(()) + } +} diff --git a/src/generator/modules/discovery/column.rs b/src/generator/modules/discovery/column.rs index 97dfe82..0d51cbc 100644 --- a/src/generator/modules/discovery/column.rs +++ b/src/generator/modules/discovery/column.rs @@ -2,7 +2,7 @@ use color_eyre::{eyre::ContextCompat, Result}; use heck::ToUpperCamelCase; use sea_schema::sea_query::{ColumnDef, ColumnSpec, ColumnType, IndexCreateStatement}; -use crate::config::sea_orm_config::DateTimeCrate; +use crate::generator::modules::sea_orm::config::DateTimeCrate; use super::db::DbType; #[derive(Clone, Debug)] diff --git a/src/generator/modules/discovery/db.rs b/src/generator/modules/discovery/db.rs index 0e93e88..124b9bb 100644 --- a/src/generator/modules/discovery/db.rs +++ b/src/generator/modules/discovery/db.rs @@ -4,8 +4,6 @@ use color_eyre::eyre::{eyre, ContextCompat, Report, Result}; use sea_schema::sea_query::TableCreateStatement; use url::Url; -use crate::config::db::DbConfig; - use super::DiscoveryConfig; #[derive(Debug, Clone)] pub enum DbType { diff --git a/src/generator/modules/discovery/mod.rs b/src/generator/modules/discovery/mod.rs index 01d9610..e2871cc 100644 --- a/src/generator/modules/discovery/mod.rs +++ b/src/generator/modules/discovery/mod.rs @@ -8,7 +8,6 @@ use color_eyre::Result; use db::DbType; use sea_schema::sea_query::TableCreateStatement; use serde::Deserialize; -use serde_inline_default::serde_inline_default; use table::Table; #[derive(Debug, Clone, Deserialize)] diff --git a/src/generator/modules/mod.rs b/src/generator/modules/mod.rs index 2417274..abd2347 100644 --- a/src/generator/modules/mod.rs +++ b/src/generator/modules/mod.rs @@ -1,25 +1,18 @@ -use std::{ - fmt::Debug, - sync::{Arc, MutexGuard}, -}; +use std::fmt::Debug; -use anymap::{ - any::{Any, CloneAny}, - Map, -}; +use annotate::AnnotateModule; +use anymap::{any::CloneAny, Map}; use color_eyre::{eyre::eyre, Result}; use discovery::DiscoveryModule; -use sea_orm::{SeaOrmConfig, SeaOrmModule}; -// use models::ModelsModule; +use models::ModelsModule; +use sea_orm::SeaOrmModule; use serde::{de::IntoDeserializer, Deserialize}; -use std::sync::Mutex; use templates::TemplateModule; -use toml_edit::{de::ValueDeserializer, DocumentMut, Item, Value}; -// use models::table::Table; -// -// use super::discover::DbType; +use toml_edit::{DocumentMut, Item}; +use super::file::FileManager; type AnyCloneMap = Map; +pub mod annotate; pub mod discovery; pub mod models; pub mod sea_orm; @@ -28,12 +21,14 @@ pub mod templates; pub struct ModulesContext { pub anymap: AnyCloneMap, pub root_config: DocumentMut, + pub file_manager: FileManager, } impl ModulesContext { pub fn new(root_config: DocumentMut) -> Self { Self { anymap: AnyCloneMap::new(), root_config, + file_manager: FileManager::new(), } } pub fn get_config_raw(&self, path: &str) -> Result<&Item> { @@ -44,12 +39,12 @@ impl ModulesContext { if let Some(v) = item.get(i) { *item = v; } else { - return Err(eyre!("Config not found")); + return Err(eyre!("Config not found \"{i}\"")); } } else if let Some(v) = self.root_config.get(i) { item = Some(v); } else { - return Err(eyre!("Config not found")); + return Err(eyre!("Config not found \"{i}\"")); } } if let Some(v) = item { @@ -58,8 +53,10 @@ impl ModulesContext { Err(eyre!("Config not found")) } } - pub fn get_config<'a, V: Deserialize<'a> + Debug>(&self, path: &str) -> Result { - let item = self.get_config_raw(path)?; + pub fn get_config<'a, V: Deserialize<'a> + Debug>(&self, path: &str) -> Result> { + let Ok(item) = self.get_config_raw(path) else { + return Ok(None); + }; let value = item .clone() .into_value() @@ -67,26 +64,28 @@ impl ModulesContext { let deserializer = value.into_deserializer(); let config = V::deserialize(deserializer)?; tracing::debug!(?config, "{}", path); - Ok(config) + Ok(Some(config)) } - pub fn get_config_auto<'a, V: Deserialize<'a> + Clone + Send + Debug + 'static>( + pub fn get_config_auto<'a, V: Deserialize<'a> + Clone + Send + Debug + Default + 'static>( &mut self, path: &str, ) -> Result<()> { - let value: V = self.get_config::(path)?; - self.get_anymap_mut().insert(value); + let value: Option = self.get_config::(path)?; + if value.is_none() { + tracing::warn!(?path, "Config not found, using default"); + } + self.get_anymap_mut().insert(value.unwrap_or_default()); Ok(()) } - // pub fn get_anymap(&self) -> MutexGuard { - // let v = self.anymap.lock().unwrap(); - // v - // } pub fn get_anymap(&self) -> &AnyCloneMap { &self.anymap } pub fn get_anymap_mut(&mut self) -> &mut AnyCloneMap { &mut self.anymap } + pub fn get_file_manager(&mut self) -> &mut FileManager { + &mut self.file_manager + } } #[async_trait::async_trait] pub trait Module: Debug { @@ -105,7 +104,9 @@ impl ModuleManager { let modules: Vec> = vec![ Box::new(TemplateModule), Box::new(DiscoveryModule), - Box::new(SeaOrmModule), //Box::new(ModelsModule) + Box::new(SeaOrmModule), + Box::new(ModelsModule), + Box::new(AnnotateModule::default()), ]; Self { modules, diff --git a/src/generator/modules/models/mod.rs b/src/generator/modules/models/mod.rs index 06baf33..18f085e 100644 --- a/src/generator/modules/models/mod.rs +++ b/src/generator/modules/models/mod.rs @@ -1,27 +1,167 @@ -use super::{Module, ModulesContext}; -use color_eyre::Result; -use serde::Deserialize; +use std::path::PathBuf; + +use crate::generator::file::pathbuf_to_rust_path; + +use super::{ + discovery::DiscoveredSchema, sea_orm::SeaOrmConfig, templates::TemplateConfig, Module, + ModulesContext, +}; +use color_eyre::{ + eyre::{eyre, Context, ContextCompat}, + Result, +}; +use handlebars::Handlebars; +use heck::ToPascalCase; +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Deserialize)] +#[serde(default)] pub struct ModelsConfig { pub enable: bool, - pub database_schema: String, - pub max_connections: u32, - pub acquire_timeout: u32, + pub path: Option, + pub prelude: bool, } -// #[derive(Debug)] -// pub struct ModelsModule; -// -// #[async_trait::async_trait] -// impl Module for ModelsModule { -// fn init(&self, ctx: &mut ModulesContext) -> Result<()> { -// Ok(()) -// } -// -// async fn validate(&self, ctx: &mut ModulesContext) -> Result { -// Ok(false) -// } -// } +impl Default for ModelsConfig { + fn default() -> Self { + Self { + enable: false, + path: None, + prelude: true, + } + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct ModelTemplateContext { + entities_path: String, + model_path: String, + model_name: String, + entity_name: String, + active_model_name: String, + prelude_path: String, +} + +impl ModelTemplateContext { + pub fn new(entities_path: String, model_name: String, prelude_path: String) -> Self { + let model_path = model_name.clone(); + let active_model_name = format!("{}ActiveModel", model_name).to_pascal_case(); + let model_name = format!("{}Model", model_name).to_pascal_case(); + let entity_name = model_path.clone().to_pascal_case(); + Self { + entities_path, + model_path, + model_name, + active_model_name, + prelude_path, + entity_name, + } + } +} + +#[derive(Debug)] +pub struct ModelsModule; + +#[async_trait::async_trait] +impl Module for ModelsModule { + fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { + ctx.get_config_auto::("modules.model")?; + Ok(()) + } + + async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { + let map = ctx.get_anymap(); + + if let (Some(config), Some(template_config), Some(sea_orm_config)) = ( + map.get::(), + map.get::(), + map.get::(), + ) { + if config.enable && !template_config.enable { + return Err(eyre!( + "\"modules.template.enable\" must be enabled to use \"modules.model.enable\"" + )); + } + if config.enable && !sea_orm_config.enable { + return Err(eyre!( + "\"modules.sea_orm.enable\" must be enabled to use \"modules.model.enable\"" + )); + } + if config.enable && config.path.is_none() { + return Err(eyre!( + "\"modules.model.path\" must be set to use \"modules.model.enable\"" + )); + } + Ok(config.enable && template_config.enable) + } else { + // One or both keys are missing + Ok(false) + } + } + async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + let mut files: Vec<(PathBuf, String)> = Vec::new(); + let map = ctx.get_anymap(); + + if let (Some(config), Some(templates), Some(sea_orm_config), Some(schema)) = ( + map.get::(), + map.get::>(), + map.get::(), + map.get::(), + ) { + let models_path = config.path.clone().unwrap(); + tracing::info!(?models_path, "Models path"); + let entities_path = sea_orm_config.path.clone().unwrap(); + let mod_path = models_path.join("mod.rs"); + if config.prelude { + files.push((mod_path.clone(), "pub mod prelude;".to_string())); + } + for table in &schema.tables { + tracing::debug!(?table, "Generating model for table"); + let path = models_path.join(format!("{}.rs", table.name)); + + let relative_entities_path = pathdiff::diff_paths(&entities_path, &path) + .context("Failed to calculate relative path")?; + let relative_entities_rust_path = pathbuf_to_rust_path(relative_entities_path); + let context = ModelTemplateContext::new( + relative_entities_rust_path.clone(), + table.name.clone(), + "super::prelude".to_string(), + ); + if config.prelude { + let prelude_part = templates + .render("model_prelude_part", &context) + .context("Failed to render model prelude part")?; + files.push((models_path.join("prelude.rs"), prelude_part.clone())); + } + + files.push((mod_path.clone(), format!("pub mod {};", table.name))); + if path.exists() { + tracing::debug!(?path, "Model file already exists"); + continue; + } + + if config.prelude { + let content = templates + .render("model_prelude", &context) + .context("Failed to render model prelude")?; + files.push((path.clone(), content.clone())); + } else { + let content = templates + .render("model", &context) + .context("Failed to render model")?; + files.push((path.clone(), content.clone())); + } + } + } else { + // One or both keys are missing + } + tracing::info!(?files, "Generated model files"); + let file_manager = ctx.get_file_manager(); + for (output_path, content) in files { + file_manager.insert(&output_path, &content, None)?; + } + Ok(()) + } +} // // // use crate::{ diff --git a/src/generator/modules/sea_orm/config.rs b/src/generator/modules/sea_orm/config.rs index 7039208..ca22082 100644 --- a/src/generator/modules/sea_orm/config.rs +++ b/src/generator/modules/sea_orm/config.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Deserializer, Serialize}; -use serde_yaml::Value; +use toml::Value; use sea_orm_codegen::{DateTimeCrate as CodegenDateTimeCrate, WithPrelude, WithSerde}; @@ -43,8 +43,8 @@ impl<'de> Deserialize<'de> for SerdeEnable { match value { Value::String(s) if s == "serialize" => Ok(SerdeEnable::Serialize), Value::String(s) if s == "deserialize" => Ok(SerdeEnable::Deserialize), - Value::Bool(true) => Ok(SerdeEnable::Both), - Value::Bool(false) => Ok(SerdeEnable::None), + Value::Boolean(true) => Ok(SerdeEnable::Both), + Value::Boolean(false) => Ok(SerdeEnable::None), _ => Err(serde::de::Error::custom( "expected 'serialize', 'deserialize', 'true' or 'false'", )), @@ -72,8 +72,8 @@ impl<'de> Deserialize<'de> for Prelude { let value = Value::deserialize(deserializer)?; match value { - Value::Bool(true) => Ok(Prelude::Enabled), - Value::Bool(false) => Ok(Prelude::Disabled), + Value::Boolean(true) => Ok(Prelude::Enabled), + Value::Boolean(false) => Ok(Prelude::Disabled), Value::String(s) if s == "allow_unused_imports" => Ok(Prelude::AllowUnusedImports), _ => Err(serde::de::Error::custom( "expected 'true', 'false', or 'allow_unused_imports'", diff --git a/src/generator/modules/sea_orm/mod.rs b/src/generator/modules/sea_orm/mod.rs index 2c1cf38..1d76f1c 100644 --- a/src/generator/modules/sea_orm/mod.rs +++ b/src/generator/modules/sea_orm/mod.rs @@ -66,6 +66,7 @@ impl Module for SeaOrmModule { } async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { let map = ctx.get_anymap(); + let mut outputs = vec![]; if let (Some(statements), Some(config), Some(discovery_config)) = ( map.get::(), map.get::(), @@ -86,18 +87,20 @@ impl Module for SeaOrmModule { config.entity.extra_derives.eenum.clone(), config.entity.extra_attributes.eenum.clone(), false, - true, + false, ); let output = EntityTransformer::transform(statements.statements.clone())? .generate(&writer_context); - for file in output.files { + outputs.extend(output.files.into_iter().map(|file| { let file_path = config.path.clone().unwrap_or_default(); let file_path = file_path.join(file.name); - tracing::info!(?file_path, "Generating file"); + (file_path, file.content) + })); + } - // let mut file_generator = crate::generator::file::FileGenerator::new(file_path); - // file_generator.write(file.content)?; - } + let file_manager = ctx.get_file_manager(); + for (output_path, content) in outputs { + file_manager.insert(&output_path, &content, None)?; } Ok(()) diff --git a/src/generator/modules/templates/mod.rs b/src/generator/modules/templates/mod.rs index 7cfb1b6..9f005cc 100644 --- a/src/generator/modules/templates/mod.rs +++ b/src/generator/modules/templates/mod.rs @@ -1,16 +1,84 @@ -use crate::generator::DatabaseUrl; +use std::{collections::HashMap, path::PathBuf}; use super::{Module, ModulesContext}; -use color_eyre::Result; +use color_eyre::{ + eyre::{eyre, ContextCompat}, + Result, +}; use handlebars::Handlebars; use serde::Deserialize; -use serde_inline_default::serde_inline_default; -#[serde_inline_default] +use include_dir::{include_dir, Dir, DirEntry}; +use tokio::fs; + +static TEMPLATE_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/templates"); + #[derive(Debug, Clone, Deserialize)] +#[serde(default)] pub struct TemplateConfig { - #[serde_inline_default(false)] pub enable: bool, + pub path: Option, + #[serde(flatten)] + pub tables: HashMap, +} +impl TemplateConfig { + pub fn to_paths(&self) -> Vec<(String, PathBuf)> { + let mut paths = Vec::new(); + for (key, value) in &self.tables { + let map_string = value.clone(); + paths.extend(map_string.into_paths(key.clone())); + } + let root = self.path.clone().unwrap_or_default(); + + paths + .into_iter() + .map(|(key, path)| { + let new_path = root.clone().join(path); + (key, new_path) + }) + .collect::>() + } +} +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +#[serde(untagged)] +pub enum MapString { + Map(HashMap), + PathBuf(PathBuf), +} +impl MapString { + pub fn into_paths(self, prefix: String) -> Vec<(String, PathBuf)> { + fn write_path(prefix: String, string: MapString) -> Vec<(String, PathBuf)> { + let mut strings = Vec::new(); + match string { + MapString::Map(inner) => { + for (key, value) in inner { + let new_prefix = if prefix.is_empty() { + key + } else { + format!("{}.{}", prefix, key) + }; + strings.extend(write_path(new_prefix, value)); + } + } + MapString::PathBuf(pathbuf) => { + strings.push((prefix, pathbuf)); + } + } + strings.sort(); + strings + } + write_path(prefix, self) + } +} + +impl Default for TemplateConfig { + fn default() -> Self { + Self { + enable: true, + tables: HashMap::new(), + path: None, + } + } } #[derive(Debug)] pub struct TemplateModule; @@ -18,22 +86,118 @@ pub struct TemplateModule; #[async_trait::async_trait] impl Module for TemplateModule { fn init(&mut self, ctx: &mut ModulesContext) -> Result<()> { - let registry: Handlebars<'static> = Handlebars::new(); - ctx.get_anymap_mut().insert(registry); + ctx.get_config_auto::("modules.template")?; Ok(()) } async fn validate(&mut self, ctx: &mut ModulesContext) -> Result { - // let map = ctx.get_anymap(); + let map = ctx.get_anymap_mut(); // - // if let (Some(config), Some(_)) = (map.get::(), map.get::()) { - // Ok(config.enable) - // } else { - // // One or both keys are missing - // Ok(false) - // } - Ok(true) + if let Some(config) = map.get::() { + if config.enable { + for templates in config.to_paths() { + let path = templates.1; + if !path.exists() { + return Err(eyre!("Template path does not exist: {}", path.display())); + } + } + } + Ok(config.enable) + } else { + Ok(false) + } } async fn execute(&mut self, ctx: &mut ModulesContext) -> Result<()> { + let mut registry: Handlebars<'static> = Handlebars::new(); + if let Some(config) = ctx.get_anymap().get::() { + for (templates, path) in config.to_paths() { + tracing::debug!(?templates, ?path, "Registering template"); + let content = fs::read_to_string(path).await?; + registry + .register_template_string(templates.as_str(), content) + .map_err(|_| eyre!("Failed to register template"))?; + } + Self::register_default_templates(TEMPLATE_DIR.entries(), &mut registry).await?; + } + ctx.get_anymap_mut().insert(registry); Ok(()) } } + +impl TemplateModule { + async fn register_default_templates<'a>( + entries: &[DirEntry<'a>], + handlebars: &mut Handlebars<'a>, + ) -> Result<()> { + for entry in entries.iter().filter(|file| { + file.path() + .extension() + .is_some_and(|f| f.to_str().is_some_and(|f| f == "hbs")) + || file.as_dir().is_some() + }) { + match entry { + DirEntry::File(file) => { + let path = file.path().with_extension(""); + let name = path + .to_str() + .context("Failed to convert path to str")? + .replace("/", "."); + let content = file + .contents_utf8() + .context(format!("Template {} failed to parse", name))?; + + tracing::debug!(?name, "Registering template"); + if !handlebars.has_template(&name) { + handlebars.register_template_string(&name, content)?; + } else { + tracing::debug!(?name, "Template already registered, skipping"); + } + } + DirEntry::Dir(dir) => { + Box::pin(Self::register_default_templates(dir.entries(), handlebars)).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + // #[test] + // fn test_map_string() { + // use super::MapString; + // let map = MapString::Map( + // vec![ + // ("a".to_string(), MapString::PathBuf("b".to_string())), + // ("c".to_string(), MapString::PathBuf("d".to_string())), + // ( + // "e".to_string(), + // MapString::Map( + // vec![("f".to_string(), MapString::String("g".to_string()))] + // .into_iter() + // .collect(), + // ), + // ), + // ] + // .into_iter() + // .collect(), + // ); + // let paths = map.into_paths("".to_string()); + // assert_eq!( + // paths, + // vec![ + // ("a".to_string(), "b".to_string()), + // ("c".to_string(), "d".to_string()), + // ("e.f".to_string(), "g".to_string()) + // ] + // ); + // } + // #[test] + // fn test_map_string2() { + // use super::MapString; + // let map = MapString::String("a".to_string()); + // let paths = map.to_paths(); + // assert_eq!(paths, vec!["a"]); + // } +} diff --git a/src/main.rs b/src/main.rs index ba96ad6..fef0124 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,9 @@ -mod config; +// mod config; mod generator; mod templates; use clap::Parser; use color_eyre::{eyre::eyre, Result}; -use config::Config; -use figment::{ - providers::{Format, Serialized, Yaml}, - Figment, -}; use handlebars::Handlebars; use tokio::{fs, io::AsyncWriteExt, process::Command}; use toml_edit::DocumentMut; @@ -44,7 +39,7 @@ async fn main() -> Result<()> { let config = fs::read_to_string(args.config).await?; let root_config = config.parse::()?; - let outputs = generator::generate(&args.database_url, root_config).await?; + generator::generate(&args.database_url, root_config).await?; // // // tracing::info!(?outputs, "Generated files"); // for output in outputs.iter() { diff --git a/src/templates.rs b/src/templates.rs index 99bb4e8..532fcdb 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -1,67 +1,28 @@ -use crate::config::Config; use color_eyre::eyre::{ContextCompat, Result}; use handlebars::Handlebars; -use include_dir::{include_dir, Dir, DirEntry}; -use serde_yaml::Value; use std::path::PathBuf; use tokio::fs; -static TEMPLATE_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/templates"); +// static TEMPLATE_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/templates"); -async fn handle_direntries<'a>( - entries: &[DirEntry<'a>], - handlebars: &mut Handlebars<'a>, -) -> Result<()> { - for entry in entries.iter().filter(|file| { - file.path() - .extension() - .is_some_and(|f| f.to_str().is_some_and(|f| f == "hbs")) - || file.as_dir().is_some() - }) { - match entry { - DirEntry::File(file) => { - let path = file.path().with_extension(""); - let name = path - .to_str() - .context("Failed to convert path to str")? - .replace("/", "."); - let content = file - .contents_utf8() - .context(format!("Template {} failed to parse", name))?; - - tracing::debug!(?name, "Registering template"); - if !handlebars.has_template(&name) { - handlebars.register_template_string(&name, content)?; - } else { - tracing::debug!(?name, "Template already registered, skipping"); - } - } - DirEntry::Dir(dir) => { - Box::pin(handle_direntries(dir.entries(), handlebars)).await?; - } - } - } - - Ok(()) -} -pub async fn register_templates(handlebars: &mut Handlebars<'_>, config: &Config) -> Result<()> { - if let Some(templates) = &config.templates { - for (name, value) in templates.iter() { - let Value::String(name) = name else { - return Err(color_eyre::eyre::eyre!("Invalid template name")); - }; - let Value::String(path) = value else { - return Err(color_eyre::eyre::eyre!("Invalid template value")); - }; - let mut path = PathBuf::from(path); - if let Some(templates_dir) = &config.templates_dir { - path = templates_dir.join(path); - } - tracing::info!(?name, ?path, "Registering template"); - let content = fs::read_to_string(path).await?; - handlebars.register_template_string(name, content)?; - } - } - handle_direntries(TEMPLATE_DIR.entries(), handlebars).await?; - Ok(()) -} +// pub async fn register_templates(handlebars: &mut Handlebars<'_>, config: &Config) -> Result<()> { +// // if let Some(templates) = &config.templates { +// // for (name, value) in templates.iter() { +// // let Value::String(name) = name else { +// // return Err(color_eyre::eyre::eyre!("Invalid template name")); +// // }; +// // let Value::String(path) = value else { +// // return Err(color_eyre::eyre::eyre!("Invalid template value")); +// // }; +// // let mut path = PathBuf::from(path); +// // if let Some(templates_dir) = &config.templates_dir { +// // path = templates_dir.join(path); +// // } +// // tracing::info!(?name, ?path, "Registering template"); +// // let content = fs::read_to_string(path).await?; +// // handlebars.register_template_string(name, content)?; +// // } +// // } +// handle_direntries(TEMPLATE_DIR.entries(), handlebars).await?; +// Ok(()) +// } diff --git a/templates/modelprelude.hbs b/templates/model_prelude.hbs similarity index 100% rename from templates/modelprelude.hbs rename to templates/model_prelude.hbs diff --git a/templates/model_prelude_part.hbs b/templates/model_prelude_part.hbs new file mode 100644 index 0000000..ddba2a8 --- /dev/null +++ b/templates/model_prelude_part.hbs @@ -0,0 +1 @@ +pub use {{entities_path}}::{{model_name}}::{ActiveModel as {{active_model_name}}, Model as {{model_name}}, Entity as {{entity_name}}}; diff --git a/tests/src/models/_entities/mod.rs b/tests/src/models/_entities/mod.rs new file mode 100644 index 0000000..a06c966 --- /dev/null +++ b/tests/src/models/_entities/mod.rs @@ -0,0 +1,3 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.8 + +pub mod user ; \ No newline at end of file diff --git a/tests/src/models/_entities/user.rs b/tests/src/models/_entities/user.rs new file mode 100644 index 0000000..f2f3243 --- /dev/null +++ b/tests/src/models/_entities/user.rs @@ -0,0 +1,9 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.8 + + + +use sea_orm :: entity :: prelude :: * ; + +# [derive (Clone , Debug , PartialEq , DeriveEntityModel , Eq)] # [sea_orm (table_name = "user")] pub struct Model { # [sea_orm (primary_key)] pub id : i32 , # [sea_orm (unique)] pub username : String , # [sea_orm (unique)] pub email : String , pub password : String , # [sea_orm (unique)] pub test : String , } + +# [derive (Copy , Clone , Debug , EnumIter , DeriveRelation)] pub enum Relation { } \ No newline at end of file diff --git a/tests/src/models/mod.rs b/tests/src/models/mod.rs new file mode 100644 index 0000000..852af71 --- /dev/null +++ b/tests/src/models/mod.rs @@ -0,0 +1 @@ +pub mod prelude;pub mod user; \ No newline at end of file diff --git a/tests/src/models/prelude.rs b/tests/src/models/prelude.rs new file mode 100644 index 0000000..c69a8a7 --- /dev/null +++ b/tests/src/models/prelude.rs @@ -0,0 +1 @@ +pub use super::_entities::UserModel::{ActiveModel as UserActiveModel, Model as UserModel, Entity as User}; diff --git a/tests/src/models/user.rs b/tests/src/models/user.rs new file mode 100644 index 0000000..f35815e --- /dev/null +++ b/tests/src/models/user.rs @@ -0,0 +1,9 @@ +use super::prelude::*; +use sea_orm::ActiveModelBehavior; + +#[async_trait::async_trait] +impl ActiveModelBehavior for UserActiveModel {} + +impl UserModel {} + +impl UserActiveModel {}