From f38051372539716771b84f728ab0de05ef322a44 Mon Sep 17 00:00:00 2001 From: Nikkuss Date: Fri, 4 Apr 2025 22:04:06 +0400 Subject: [PATCH] rewrite generator entirely, again --- Cargo.toml | 2 +- src/{config.rs => config.rs.old} | 47 ++++- src/config/db.rs | 8 + src/config/mod.rs | 86 +++++++++ src/config/output.rs | 32 ++++ src/config/sea_orm_config.rs | 233 ++++++++++++++++++++++++ src/config/template.rs | 1 + src/generator/column.rs | 32 ---- src/generator/file.rs | 113 +++--------- src/generator/mod.rs | 96 +++++++--- src/generator/models/column.rs | 236 +++++++++++++++++++++++++ src/generator/models/comment.rs | 154 ++++++++++++++++ src/generator/{ => models}/discover.rs | 48 +++-- src/generator/models/file.rs | 101 +++++++++++ src/generator/models/mod.rs | 153 ++++++++++++++++ src/generator/{ => models}/table.rs | 8 +- src/main.rs | 97 ++++++---- src/templates.rs | 67 +++++++ src/templates/model.rs.hbs | 12 -- templates/model.hbs | 9 + templates/modelprelude.hbs | 9 + 21 files changed, 1323 insertions(+), 221 deletions(-) rename src/{config.rs => config.rs.old} (86%) create mode 100644 src/config/db.rs create mode 100644 src/config/mod.rs create mode 100644 src/config/output.rs create mode 100644 src/config/sea_orm_config.rs create mode 100644 src/config/template.rs delete mode 100644 src/generator/column.rs create mode 100644 src/generator/models/column.rs create mode 100644 src/generator/models/comment.rs rename src/generator/{ => models}/discover.rs (80%) create mode 100644 src/generator/models/file.rs create mode 100644 src/generator/models/mod.rs rename src/generator/{ => models}/table.rs (87%) create mode 100644 src/templates.rs delete mode 100644 src/templates/model.rs.hbs create mode 100644 templates/model.hbs create mode 100644 templates/modelprelude.hbs diff --git a/Cargo.toml b/Cargo.toml index f3281ab..22832de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] clap = { version = "4.5.32", features = ["derive", 
"env"] } color-eyre = "0.6.3" +comfy-table = { version = "7.1.4", default-features = false } comment-parser = "0.1.0" figment = { version = "0.10.19", features = ["yaml"] } handlebars = "6.3.2" @@ -13,7 +14,6 @@ heck = "0.5.0" include_dir = "0.7.4" indicatif = "0.17.11" inquire = "0.7.5" -prettytable = "0.10.0" quote = "1.0.40" sea-orm-codegen = "1.1.8" sea-schema = { version = "0.16.1", features = ["sqlx-all"] } diff --git a/src/config.rs b/src/config.rs.old similarity index 86% rename from src/config.rs rename to src/config.rs.old index f1113b9..85e0866 100644 --- a/src/config.rs +++ b/src/config.rs.old @@ -5,7 +5,7 @@ use sea_orm_codegen::{ DateTimeCrate as CodegenDateTimeCrate, EntityWriterContext, WithPrelude, WithSerde, }; use serde::{Deserialize, Deserializer, Serialize}; -use serde_yaml::Value; +use serde_yaml::{Mapping, Value}; use tracing::instrument; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -104,12 +104,34 @@ pub struct Config { pub db: DbConfig, pub sea_orm: SeaOrmConfig, pub output: OutputConfig, - pub templates: TemplateConfig, + pub templates: Option, + pub templates_dir: Option, } #[derive(Deserialize, Serialize, Debug, Clone)] pub struct OutputConfig { pub path: PathBuf, + pub models: OutputModelConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputCommentConfig { + pub enable: bool, + pub max_width: Option, + pub table_name: bool, + pub column_info: bool, + pub column_name: bool, + pub column_db_type: bool, + pub column_rust_type: bool, + pub column_attributes: bool, + pub column_exclude_attributes: Vec, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputModelConfig { + pub prelude: bool, + pub path: PathBuf, + pub comment: OutputCommentConfig, + pub entities: String, } #[derive(Deserialize, Serialize, Debug, Clone)] @@ -246,9 +268,26 @@ impl Default for Config { }, }, output: OutputConfig { - path: PathBuf::from("./entities"), + path: PathBuf::from("./src/"), + models: OutputModelConfig { + 
comment: OutputCommentConfig { + max_width: None, + table_name: true, + column_name: true, + column_db_type: true, + column_rust_type: true, + column_attributes: true, + column_exclude_attributes: Vec::new(), + enable: true, + column_info: true, + }, + entities: String::from("_entities"), + prelude: true, + path: PathBuf::from("./models"), + }, }, - templates: TemplateConfig { model: None }, + templates: None, + templates_dir: None, } } } diff --git a/src/config/db.rs b/src/config/db.rs new file mode 100644 index 0000000..a650faa --- /dev/null +++ b/src/config/db.rs @@ -0,0 +1,8 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct DbConfig { + pub database_schema: Option, + pub max_connections: u32, + pub acquire_timeout: u64, +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..05b784a --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,86 @@ +pub mod db; +pub mod output; +pub mod sea_orm_config; +pub mod template; + +use std::path::PathBuf; + +use db::DbConfig; +use output::{OutputCommentConfig, OutputConfig, OutputModelConfig}; +use sea_orm_config::{ + DateTimeCrate, EntityFormat, Prelude, SeaOrmConfig, SeaOrmEntityConfig, + SeaOrmExtraAttributesConfig, SeaOrmExtraDerivesConfig, SeaOrmSerdeConfig, SeaOrmTableConfig, + SerdeEnable, +}; +use serde::{Deserialize, Serialize}; +use serde_yaml::Mapping; + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct Config { + pub db: DbConfig, + pub sea_orm: SeaOrmConfig, + pub output: OutputConfig, + pub templates: Option, + pub templates_dir: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + db: DbConfig { + database_schema: None, + max_connections: 10, + acquire_timeout: 5, + }, + sea_orm: SeaOrmConfig { + prelude: Prelude::Enabled, + serde: SeaOrmSerdeConfig { + enable: SerdeEnable::None, + skip_deserializing_primary_key: false, + skip_hidden_column: false, + }, + entity: SeaOrmEntityConfig { + format: 
EntityFormat::Compact, + tables: SeaOrmTableConfig { + include_hidden: false, + skip_seaql_migrations: true, + table_config: None, + }, + extra_derives: SeaOrmExtraDerivesConfig { + model: Vec::new(), + eenum: Vec::new(), + }, + extra_attributes: SeaOrmExtraAttributesConfig { + model: Vec::new(), + eenum: Vec::new(), + }, + date_time_crate: DateTimeCrate::Chrono, + with_copy_enums: false, + }, + }, + output: OutputConfig { + path: PathBuf::from("./src/"), + models: OutputModelConfig { + comment: OutputCommentConfig { + max_width: None, + table_name: true, + column_name: true, + column_db_type: true, + column_rust_type: true, + column_attributes: true, + column_exclude_attributes: Vec::new(), + enable: true, + column_info: true, + ignore_errors: false, + }, + enable: true, + entities: String::from("_entities"), + prelude: true, + path: PathBuf::from("./models"), + }, + }, + templates: None, + templates_dir: None, + } + } +} diff --git a/src/config/output.rs b/src/config/output.rs new file mode 100644 index 0000000..eed6e04 --- /dev/null +++ b/src/config/output.rs @@ -0,0 +1,32 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputConfig { + pub path: PathBuf, + pub models: OutputModelConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputModelConfig { + pub prelude: bool, + pub enable: bool, + pub path: PathBuf, + pub comment: OutputCommentConfig, + pub entities: String, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct OutputCommentConfig { + pub enable: bool, + pub max_width: Option, + pub table_name: bool, + pub column_info: bool, + pub column_name: bool, + pub column_db_type: bool, + pub column_rust_type: bool, + pub column_attributes: bool, + pub column_exclude_attributes: Vec, + pub ignore_errors: bool, +} diff --git a/src/config/sea_orm_config.rs b/src/config/sea_orm_config.rs new file mode 100644 index 0000000..f2da53c --- /dev/null +++ 
b/src/config/sea_orm_config.rs @@ -0,0 +1,233 @@ +use serde::{Deserialize, Deserializer, Serialize}; +use serde_yaml::Value; + +use sea_orm_codegen::{ + DateTimeCrate as CodegenDateTimeCrate, EntityWriterContext, WithPrelude, WithSerde, +}; + +use super::Config; +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum EntityFormat { + Expanded, + Compact, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] +pub enum TableConfig { + Specific { specific: Vec }, + Exclude { exclude: Vec }, +} + +#[derive(Debug, Clone)] +pub enum SerdeEnable { + Both, + Serialize, + Deserialize, + None, +} + +#[derive(Debug, Clone)] +pub enum Prelude { + Enabled, + Disabled, + AllowUnusedImports, +} + +impl<'de> Deserialize<'de> for SerdeEnable { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) if s == "serialize" => Ok(SerdeEnable::Serialize), + Value::String(s) if s == "deserialize" => Ok(SerdeEnable::Deserialize), + Value::Bool(true) => Ok(SerdeEnable::Both), + Value::Bool(false) => Ok(SerdeEnable::None), + _ => Err(serde::de::Error::custom( + "expected 'serialize', 'deserialize', 'true' or 'false'", + )), + } + } +} +impl Serialize for SerdeEnable { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + SerdeEnable::Both => serializer.serialize_bool(true), + SerdeEnable::Serialize => serializer.serialize_str("serialize"), + SerdeEnable::Deserialize => serializer.serialize_str("deserialize"), + SerdeEnable::None => serializer.serialize_bool(false), + } + } +} +impl<'de> Deserialize<'de> for Prelude { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::Bool(true) => Ok(Prelude::Enabled), + Value::Bool(false) => 
Ok(Prelude::Disabled), + Value::String(s) if s == "allow_unused_imports" => Ok(Prelude::AllowUnusedImports), + _ => Err(serde::de::Error::custom( + "expected 'true', 'false', or 'allow_unused_imports'", + )), + } + } +} +impl Serialize for Prelude { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Prelude::Enabled => serializer.serialize_bool(true), + Prelude::Disabled => serializer.serialize_bool(false), + Prelude::AllowUnusedImports => serializer.serialize_str("allow_unused_imports"), + } + } +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmConfig { + pub prelude: Prelude, + pub serde: SeaOrmSerdeConfig, + pub entity: SeaOrmEntityConfig, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmSerdeConfig { + pub enable: SerdeEnable, + pub skip_deserializing_primary_key: bool, + pub skip_hidden_column: bool, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmEntityConfig { + pub format: EntityFormat, + pub tables: SeaOrmTableConfig, + pub extra_derives: SeaOrmExtraDerivesConfig, + pub extra_attributes: SeaOrmExtraAttributesConfig, + pub date_time_crate: DateTimeCrate, + pub with_copy_enums: bool, +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmTableConfig { + pub include_hidden: bool, + pub skip_seaql_migrations: bool, + #[serde(flatten)] + pub table_config: Option, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmExtraDerivesConfig { + pub model: Vec, + #[serde(rename = "enum")] + pub eenum: Vec, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct SeaOrmExtraAttributesConfig { + pub model: Vec, + #[serde(rename = "enum")] + pub eenum: Vec, +} +#[derive(Deserialize, Serialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum DateTimeCrate { + Time, + Chrono, +} + +impl From for CodegenDateTimeCrate { + fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate { + match 
date_time_crate { + DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono, + DateTimeCrate::Time => CodegenDateTimeCrate::Time, + } + } +} + +impl SeaOrmTableConfig { + pub fn get_filter(&self) -> Box bool> { + let include_hidden = self.include_hidden; + if let Some(table) = &self.table_config { + match table { + TableConfig::Specific { specific } => { + let specific = specific.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && specific.contains(table) + }) + } + TableConfig::Exclude { exclude } => { + let exclude = exclude.clone(); + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) && !exclude.contains(table) + }) + } + } + } else if self.skip_seaql_migrations { + Box::new(move |table: &String| { + (include_hidden || !table.starts_with('_')) + && !table.starts_with("seaql_migrations") + }) + } else { + Box::new(move |table: &String| (include_hidden || !table.starts_with('_'))) + } + } +} + +impl EntityFormat { + pub fn is_expanded(&self) -> bool { + matches!(self, EntityFormat::Expanded) + } +} +impl From for WithPrelude { + fn from(val: Prelude) -> Self { + match val { + Prelude::Enabled => WithPrelude::All, + + Prelude::Disabled => WithPrelude::None, + Prelude::AllowUnusedImports => WithPrelude::AllAllowUnusedImports, + } + } +} +impl From for WithSerde { + fn from(val: SerdeEnable) -> Self { + match val { + SerdeEnable::Both => WithSerde::Both, + SerdeEnable::Serialize => WithSerde::Serialize, + SerdeEnable::Deserialize => WithSerde::Deserialize, + SerdeEnable::None => WithSerde::None, + } + } +} + +impl From for EntityWriterContext { + fn from(val: Config) -> Self { + EntityWriterContext::new( + val.sea_orm.entity.format.is_expanded(), + val.sea_orm.prelude.into(), + val.sea_orm.serde.enable.into(), + val.sea_orm.entity.with_copy_enums, + val.sea_orm.entity.date_time_crate.into(), + val.db.database_schema, + false, + val.sea_orm.serde.skip_deserializing_primary_key, + 
val.sea_orm.serde.skip_hidden_column, + val.sea_orm.entity.extra_derives.model, + val.sea_orm.entity.extra_attributes.model, + val.sea_orm.entity.extra_derives.eenum, + val.sea_orm.entity.extra_attributes.eenum, + false, + false, + ) + } +} diff --git a/src/config/template.rs b/src/config/template.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/config/template.rs @@ -0,0 +1 @@ + diff --git a/src/generator/column.rs b/src/generator/column.rs deleted file mode 100644 index cfbfd9d..0000000 --- a/src/generator/column.rs +++ /dev/null @@ -1,32 +0,0 @@ -use color_eyre::{eyre::ContextCompat, Result}; -use sea_schema::sea_query::{ColumnDef, ColumnSpec, ColumnType, IndexCreateStatement}; -#[derive(Clone, Debug)] -pub struct Column { - pub name: String, - pub col_type: ColumnType, - pub attrs: Vec, -} - -impl Column { - pub fn new(column: ColumnDef, index: Option) -> Result { - let name = column.get_column_name(); - let col_type = column - .get_column_type() - .context("Unable to get column type")? 
- .clone(); - let mut attrs = column.get_column_spec().clone(); - if let Some(index) = index { - if index.is_unique_key() { - attrs.push(ColumnSpec::UniqueKey) - } - if index.is_primary_key() { - attrs.push(ColumnSpec::PrimaryKey); - } - } - Ok(Column { - name: name.to_string(), - col_type, - attrs: attrs.to_vec(), - }) - } -} diff --git a/src/generator/file.rs b/src/generator/file.rs index f0e1610..5d16544 100644 --- a/src/generator/file.rs +++ b/src/generator/file.rs @@ -1,95 +1,38 @@ -use std::path::PathBuf; - -use crate::config::Config; - -use super::table::Table; use color_eyre::Result; -use handlebars::Handlebars; -use prettytable::{format, row, Table as PTable}; -use sea_orm_codegen::OutputFile; -use serde::Serialize; -use tokio::fs; -const HEADER: &str = r#"== Schema Information"#; -const COMMENTHEAD: &str = r#"/*"#; -const COMMENTBODY: &str = r#" *"#; -const COMMENTTAIL: &str = r#"*/"#; +use std::{collections::HashMap, path::PathBuf}; #[derive(Debug, Clone)] -pub struct FileGenerator { - filename: String, - table: Table, -} -#[derive(Debug, Clone, Serialize)] -pub struct FileContext { - has_prelude: bool, - prelude: String, +pub struct GeneratedFileChunk { + pub path: PathBuf, + pub content: String, + pub priority: i32, } -impl FileGenerator { - pub fn new(table: Table) -> Result { - let filename = format!("{}.rs", table.name); - Ok(FileGenerator { table, filename }) - } - pub async fn build_file<'a>( - &self, - config: &Config, - handlebars: &'a Handlebars<'a>, - ) -> Result { - let filepath = config.output.path.join(&self.filename); - Ok(OutputFile { - name: filepath.to_str().unwrap().to_string(), - content: self.generate_file(config, handlebars).await?, - }) - } - pub async fn generate_file<'a>( - &self, - config: &Config, - handlebars: &'a Handlebars<'a>, - ) -> Result { - let filepath = config.output.path.join(&self.filename); - let file_context = FileContext { - has_prelude: false, - prelude: String::new(), - }; - let generated_header = 
self.generate_header(config).await?; - if filepath.exists() { - let mut file_content = fs::read_to_string(filepath).await?; - Ok(file_content) +#[derive(Debug, Clone)] +pub struct GeneratedFile { + pub path: PathBuf, + pub content: String, +} + +pub fn combine_chunks(chunks: Vec) -> Result> { + let mut table: HashMap> = HashMap::new(); + for chunk in chunks { + let path = chunk.path.clone(); + if let Some(v) = table.get_mut(&path) { + v.push(chunk); } else { - let content = handlebars.render("model", &file_context)?; - Ok(format!("{}{}", generated_header, content)) + table.insert(path, vec![chunk]); } } - pub async fn generate_header(&self, config: &Config) -> Result { - let mut column_info_table = PTable::new(); - let format = format::FormatBuilder::new() - .column_separator(' ') - .borders(' ') - .separators( - &[ - format::LinePosition::Bottom, - format::LinePosition::Title, - // format::LinePosition::Top, - ], - format::LineSeparator::default(), - ) - .padding(1, 1) - .build(); - column_info_table.set_format(format); - column_info_table.set_titles(row!["Name", "Type", "RustType", "Attributes"]); - // let indexes = table.get_indexes(); - // tracing::info!(?indexes); - // for column in table.get_columns() { - // let name = column.get_column_name(); - // if let Some(column_type) = column.get_column_type() { - // let column_type_rust = - // type_to_rust_string(column_type, config.sea_orm.entity.date_time_crate.clone()); - // let column_type = - // type_to_string(column_type, config.sea_orm.entity.date_time_crate.clone()); - // let attrs = attrs_to_string(column.get_column_spec()); - // ptable.add_row(row![name, column_type, column_type_rust, attrs]); - // } - // } - Ok(String::new()) + + let mut files = Vec::new(); + for (path, mut chunks) in table { + chunks.sort_by(|a, b| a.priority.cmp(&b.priority)); + let mut content = String::new(); + for chunk in chunks { + content.push_str(&chunk.content); + } + files.push(GeneratedFile { path, content }); } + Ok(files) } 
diff --git a/src/generator/mod.rs b/src/generator/mod.rs index efe40d9..b6611ae 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -1,42 +1,82 @@ -pub mod column; -pub mod discover; +use file::{GeneratedFile, GeneratedFileChunk}; + +// pub mod column; +// pub mod discover; pub mod file; -pub mod table; - -use core::time; - -use crate::{config::DateTimeCrate, Config}; +pub mod models; +// pub mod table; +// +// use core::time; +// +// use crate::{config::DateTimeCrate, Config}; use color_eyre::{ eyre::{eyre, ContextCompat, Report}, Result, }; -use comment_parser::{CommentParser, Event}; use handlebars::Handlebars; -use prettytable::{format, row}; -use sea_orm_codegen::OutputFile; -use sea_schema::sea_query::{self, ColumnSpec, ColumnType, StringLen, TableCreateStatement}; -use table::Table; -use tokio::{fs, task::JoinSet}; -use url::Url; +use models::ModelConfig; -pub async fn update_files<'a>( - tables: Vec, - config: Config, +use crate::config::Config; +// use comment_parser::{CommentParser, Event}; +// use discover::DbType; +// use handlebars::Handlebars; +// use sea_orm_codegen::OutputFile; +// use sea_schema::sea_query::{self, ColumnSpec, ColumnType, StringLen, TableCreateStatement}; +// use table::Table; +// use tokio::{fs, task::JoinSet}; +// use url::Url; + +pub async fn generate<'a>( + database_url: &str, + config: &Config, handlebars: &'a Handlebars<'a>, -) -> Result> { - let tables = tables - .into_iter() - .map(|table| Table::new(table.clone(), config.clone())) - .collect::>>()?; +) -> Result> { let mut files = Vec::new(); - for table in tables { - let comment = file::FileGenerator::new(table)?; - let file = comment.build_file(&config, &handlebars).await?; - files.push(file); - } - Ok(Vec::new()) + let model_outputs = models::generate_models(database_url, config, handlebars).await?; + files.extend(model_outputs); + Ok(files) } +// pub async fn update_files<'a>( +// tables: Vec, +// config: Config, +// handlebars: &'a Handlebars<'a>, +// 
db_type: DbType, +// ) -> Result> { +// let mut files = Vec::new(); +// let entities_path = &config.output.models.entities.replace("/", "::"); +// let entities_path_split = entities_path.split("::").collect::>(); +// +// let mut mod_file_content = String::new(); +// mod_file_content +// .push_str(format!("pub mod {};\n", entities_path_split.first().unwrap_or(&"")).as_str()); +// for table in tables { +// mod_file_content.push_str(format!("pub mod {};\n", table.name).as_str()); +// let entities_path = format!("super::{}", entities_path); +// let comment = file::FileGenerator::new(table, entities_path)?; +// let file = comment.build_file(&config, handlebars, &db_type).await?; +// tracing::info!(?file.name, file.content.len = file.content.len()); +// files.push(file); +// } +// if entities_path_split.len() > 1 { +// for index in 0..entities_path_split.len() - 1 { +// let entity = entities_path_split[index]; +// let next = entities_path_split[index + 1]; +// files.push(OutputFile { +// name: format!("{}/mod.rs", entity), +// content: format!("pub mod {};\n", next), +// }); +// } +// } +// +// files.push(OutputFile { +// name: "mod.rs".to_string(), +// content: mod_file_content, +// }); +// +// Ok(files) +// } +// // pub async fn generate_models( // tables: Vec, // config: Config, diff --git a/src/generator/models/column.rs b/src/generator/models/column.rs new file mode 100644 index 0000000..ba884d6 --- /dev/null +++ b/src/generator/models/column.rs @@ -0,0 +1,236 @@ +use color_eyre::{eyre::ContextCompat, Result}; +use comfy_table::Cell; +use heck::ToUpperCamelCase; +use sea_schema::sea_query::{ColumnDef, ColumnSpec, ColumnType, IndexCreateStatement}; + +use crate::config::{sea_orm_config::DateTimeCrate, Config}; + +use super::{discover::DbType, ModelConfig}; +#[derive(Clone, Debug)] +pub struct Column { + pub name: String, + pub col_type: ColumnType, + pub attrs: Vec, +} + +impl Column { + pub fn new(column: ColumnDef, index: Option) -> Result { + let name = 
column.get_column_name(); + let col_type = column + .get_column_type() + .context("Unable to get column type")? + .clone(); + let mut attrs = column.get_column_spec().clone(); + if let Some(index) = index { + if index.is_unique_key() { + attrs.push(ColumnSpec::UniqueKey) + } + if index.is_primary_key() { + attrs.push(ColumnSpec::PrimaryKey); + } + } + Ok(Column { + name: name.to_string(), + col_type, + attrs: attrs.to_vec(), + }) + } + pub fn get_info_row(&self, config: &ModelConfig) -> Result> { + let column_type_rust = self.get_rs_type(&config.comment.date_time_crate); + let column_type = self.get_db_type(&config.db_type); + let attrs = self.attrs_to_string(); + let mut cols = Vec::new(); + if config.comment.column_name { + cols.push(Cell::new(self.name.clone())) + } + if config.comment.column_name { + cols.push(Cell::new(column_type.clone())) + } + if config.comment.column_rust_type { + cols.push(Cell::new(column_type_rust.clone())) + } + if config.comment.column_attributes { + cols.push(Cell::new(attrs.clone())); + } + Ok(cols) + } + pub fn attrs_to_string(&self) -> String { + self.attrs + .iter() + .filter_map(Self::get_addr_type) + .map(|s| s.to_string()) + .collect::>() + .join(", ") + } + pub fn get_addr_type(attr: &ColumnSpec) -> Option { + match attr { + ColumnSpec::PrimaryKey => Some("primary key".to_owned()), + ColumnSpec::Null => todo!(), + ColumnSpec::NotNull => Some("not null".to_owned()), + ColumnSpec::Default(simple_expr) => todo!(), + ColumnSpec::AutoIncrement => Some("autoincrement".to_owned()), + ColumnSpec::UniqueKey => Some("unique key".to_owned()), + ColumnSpec::Check(simple_expr) => todo!(), + ColumnSpec::Generated { expr, stored } => todo!(), + ColumnSpec::Extra(_) => todo!(), + ColumnSpec::Comment(_) => todo!(), + } + } + pub fn get_db_type(&self, db_type: &DbType) -> String { + fn write_db_type(col_type: &ColumnType, db_type: &DbType) -> String { + #[allow(unreachable_patterns)] + match (col_type, db_type) { + (ColumnType::Char(_), _) => 
"char".to_owned(), + (ColumnType::String(_), _) => "varchar".to_owned(), + (ColumnType::Text, _) => "text".to_owned(), + (ColumnType::TinyInteger, DbType::MySql | DbType::Sqlite) => "tinyint".to_owned(), + (ColumnType::TinyInteger, DbType::Postgres) => "smallint".to_owned(), + (ColumnType::SmallInteger, _) => "smallint".to_owned(), + (ColumnType::Integer, DbType::MySql) => "int".to_owned(), + (ColumnType::Integer, _) => "integer".to_owned(), + (ColumnType::BigInteger, DbType::MySql | DbType::Postgres) => "bigint".to_owned(), + (ColumnType::BigInteger, DbType::Sqlite) => "integer".to_owned(), + (ColumnType::TinyUnsigned, DbType::MySql) => "tinyint unsigned".to_owned(), + (ColumnType::TinyUnsigned, DbType::Postgres) => "smallint".to_owned(), + (ColumnType::TinyUnsigned, DbType::Sqlite) => "tinyint".to_owned(), + (ColumnType::SmallUnsigned, DbType::MySql) => "smallint unsigned".to_owned(), + (ColumnType::SmallUnsigned, DbType::Postgres | DbType::Sqlite) => { + "smallint".to_owned() + } + (ColumnType::Unsigned, DbType::MySql) => "int unsigned".to_owned(), + (ColumnType::Unsigned, DbType::Postgres | DbType::Sqlite) => "integer".to_owned(), + (ColumnType::BigUnsigned, DbType::MySql) => "bigint unsigned".to_owned(), + (ColumnType::BigUnsigned, DbType::Postgres) => "bigint".to_owned(), + (ColumnType::BigUnsigned, DbType::Sqlite) => "integer".to_owned(), + (ColumnType::Float, DbType::MySql | DbType::Sqlite) => "float".to_owned(), + (ColumnType::Float, DbType::Postgres) => "real".to_owned(), + (ColumnType::Double, DbType::MySql | DbType::Sqlite) => "double".to_owned(), + (ColumnType::Double, DbType::Postgres) => "double precision".to_owned(), + (ColumnType::Decimal(_), DbType::MySql | DbType::Postgres) => "decimal".to_owned(), + (ColumnType::Decimal(_), DbType::Sqlite) => "real".to_owned(), + (ColumnType::DateTime, DbType::MySql) => "datetime".to_owned(), + (ColumnType::DateTime, DbType::Postgres) => "timestamp w/o tz".to_owned(), + (ColumnType::DateTime, DbType::Sqlite) => 
"datetime_text".to_owned(), + (ColumnType::Timestamp, DbType::MySql | DbType::Postgres) => "timestamp".to_owned(), + (ColumnType::Timestamp, DbType::Sqlite) => "timestamp_text".to_owned(), + (ColumnType::TimestampWithTimeZone, DbType::MySql) => "timestamp".to_owned(), + (ColumnType::TimestampWithTimeZone, DbType::Postgres) => { + "timestamp w tz".to_owned() + } + (ColumnType::TimestampWithTimeZone, DbType::Sqlite) => { + "timestamp_with_timezone_text".to_owned() + } + (ColumnType::Time, DbType::MySql | DbType::Postgres) => "time".to_owned(), + (ColumnType::Time, DbType::Sqlite) => "time_text".to_owned(), + (ColumnType::Date, DbType::MySql | DbType::Postgres) => "date".to_owned(), + (ColumnType::Date, DbType::Sqlite) => "date_text".to_owned(), + (ColumnType::Year, DbType::MySql) => "year".to_owned(), + (ColumnType::Interval(_, _), DbType::Postgres) => "interval".to_owned(), + (ColumnType::Blob, DbType::MySql | DbType::Sqlite) => "blob".to_owned(), + (ColumnType::Blob, DbType::Postgres) => "bytea".to_owned(), + (ColumnType::Binary(_), DbType::MySql) => "binary".to_owned(), + (ColumnType::Binary(_), DbType::Postgres) => "bytea".to_owned(), + (ColumnType::Binary(_), DbType::Sqlite) => "blob".to_owned(), + (ColumnType::VarBinary(_), DbType::MySql) => "varbinary".to_owned(), + (ColumnType::VarBinary(_), DbType::Postgres) => "bytea".to_owned(), + (ColumnType::VarBinary(_), DbType::Sqlite) => "varbinary_blob".to_owned(), + (ColumnType::Bit(_), DbType::MySql | DbType::Postgres) => "bit".to_owned(), + (ColumnType::VarBit(_), DbType::MySql) => "bit".to_owned(), + (ColumnType::VarBit(_), DbType::Postgres) => "varbit".to_owned(), + (ColumnType::Boolean, DbType::MySql | DbType::Postgres) => "bool".to_owned(), + (ColumnType::Boolean, DbType::Sqlite) => "boolean".to_owned(), + (ColumnType::Money(_), DbType::MySql) => "decimal".to_owned(), + (ColumnType::Money(_), DbType::Postgres) => "money".to_owned(), + (ColumnType::Money(_), DbType::Sqlite) => "real_money".to_owned(), + 
(ColumnType::Json, DbType::MySql | DbType::Postgres) => "json".to_owned(), + (ColumnType::Json, DbType::Sqlite) => "json_text".to_owned(), + (ColumnType::JsonBinary, DbType::MySql) => "json".to_owned(), + (ColumnType::JsonBinary, DbType::Postgres) => "jsonb".to_owned(), + (ColumnType::JsonBinary, DbType::Sqlite) => "jsonb_text".to_owned(), + (ColumnType::Uuid, DbType::MySql) => "binary(16)".to_owned(), + (ColumnType::Uuid, DbType::Postgres) => "uuid".to_owned(), + (ColumnType::Uuid, DbType::Sqlite) => "uuid_text".to_owned(), + (ColumnType::Enum { name, .. }, DbType::MySql) => { + format!("ENUM({})", name.to_string().to_upper_camel_case()) + } + (ColumnType::Enum { name, .. }, DbType::Postgres) => { + name.to_string().to_uppercase() + } + (ColumnType::Enum { .. }, DbType::Sqlite) => "enum_text".to_owned(), + (ColumnType::Array(column_type), DbType::Postgres) => { + format!("{}[]", write_db_type(column_type, db_type)).to_uppercase() + } + (ColumnType::Vector(_), DbType::Postgres) => "vector".to_owned(), + (ColumnType::Cidr, DbType::Postgres) => "cidr".to_owned(), + (ColumnType::Inet, DbType::Postgres) => "inet".to_owned(), + (ColumnType::MacAddr, DbType::Postgres) => "macaddr".to_owned(), + (ColumnType::LTree, DbType::Postgres) => "ltree".to_owned(), + + _ => unimplemented!(), + } + } + write_db_type(&self.col_type, db_type) + } + pub fn get_rs_type(&self, date_time_crate: &DateTimeCrate) -> String { + fn write_rs_type(col_type: &ColumnType, date_time_crate: &DateTimeCrate) -> String { + #[allow(unreachable_patterns)] + match col_type { + ColumnType::Char(_) + | ColumnType::String(_) + | ColumnType::Text + | ColumnType::Custom(_) => "String".to_owned(), + ColumnType::TinyInteger => "i8".to_owned(), + ColumnType::SmallInteger => "i16".to_owned(), + ColumnType::Integer => "i32".to_owned(), + ColumnType::BigInteger => "i64".to_owned(), + ColumnType::TinyUnsigned => "u8".to_owned(), + ColumnType::SmallUnsigned => "u16".to_owned(), + ColumnType::Unsigned => 
"u32".to_owned(), + ColumnType::BigUnsigned => "u64".to_owned(), + ColumnType::Float => "f32".to_owned(), + ColumnType::Double => "f64".to_owned(), + ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(), + ColumnType::Date => match date_time_crate { + DateTimeCrate::Chrono => "Date".to_owned(), + DateTimeCrate::Time => "TimeDate".to_owned(), + }, + ColumnType::Time => match date_time_crate { + DateTimeCrate::Chrono => "Time".to_owned(), + DateTimeCrate::Time => "TimeTime".to_owned(), + }, + ColumnType::DateTime => match date_time_crate { + DateTimeCrate::Chrono => "DateTime".to_owned(), + DateTimeCrate::Time => "TimeDateTime".to_owned(), + }, + ColumnType::Timestamp => match date_time_crate { + DateTimeCrate::Chrono => "DateTimeUtc".to_owned(), + DateTimeCrate::Time => "TimeDateTime".to_owned(), + }, + ColumnType::TimestampWithTimeZone => match date_time_crate { + DateTimeCrate::Chrono => "DateTimeWithTimeZone".to_owned(), + DateTimeCrate::Time => "TimeDateTimeWithTimeZone".to_owned(), + }, + ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(), + ColumnType::Uuid => "Uuid".to_owned(), + ColumnType::Binary(_) | ColumnType::VarBinary(_) | ColumnType::Blob => { + "Vec".to_owned() + } + ColumnType::Boolean => "bool".to_owned(), + ColumnType::Enum { name, .. 
} => name.to_string().to_upper_camel_case(), + ColumnType::Array(column_type) => { + format!("Vec<{}>", write_rs_type(column_type, date_time_crate)) + } + ColumnType::Vector(_) => "::pgvector::Vector".to_owned(), + ColumnType::Bit(None | Some(1)) => "bool".to_owned(), + ColumnType::Bit(_) | ColumnType::VarBit(_) => "Vec".to_owned(), + ColumnType::Year => "i32".to_owned(), + ColumnType::Cidr | ColumnType::Inet => "IpNetwork".to_owned(), + ColumnType::Interval(_, _) | ColumnType::MacAddr | ColumnType::LTree => { + "String".to_owned() + } + _ => unimplemented!(), + } + } + write_rs_type(&self.col_type, date_time_crate) + } +} diff --git a/src/generator/models/comment.rs b/src/generator/models/comment.rs new file mode 100644 index 0000000..54098da --- /dev/null +++ b/src/generator/models/comment.rs @@ -0,0 +1,154 @@ +use crate::generator::models::{CommentConfig, CommentConfigSerde}; + +use super::{table::Table, ModelConfig}; +use color_eyre::{eyre, owo_colors::colors::White, Result}; +use comfy_table::{ContentArrangement, Table as CTable}; +use comment_parser::{CommentParser, Event}; + +const HEADER: &str = r#"== Schema Information"#; +const COMMENTHEAD: &str = r#"/*"#; +const COMMENTBODY: &str = r#" *"#; +const COMMENTTAIL: &str = r#"*/"#; +const SETTINGSDELIMITER: &str = r#"```"#; + +pub struct ModelCommentGenerator {} + +impl ModelCommentGenerator { + pub fn find_settings_block(file_content: &str) -> Option { + let delimiter_length = SETTINGSDELIMITER.len(); + let start_pos = file_content.find(SETTINGSDELIMITER)?; + let end_pos = file_content[start_pos + delimiter_length..].find(SETTINGSDELIMITER)?; + let content = &file_content[start_pos + delimiter_length..start_pos + end_pos]; + let content = content.replace(&format!("\n{COMMENTBODY}"), "\n"); + Some(content) + } + pub fn generate_comment( + table: Table, + file_content: &str, + config: &ModelConfig, + ) -> Result { + let rules = comment_parser::get_syntax("rust").unwrap(); + let parser = 
CommentParser::new(&file_content, rules); + for comment in parser { + if let Event::BlockComment(body, _) = comment { + if body.contains(HEADER) { + tracing::debug!("Found header"); + let mut settings = config.comment.clone(); + let mut new_settings = None; + if let Some(parsed_settings) = Self::find_settings_block(file_content) { + tracing::info!(?new_settings); + match serde_yaml::from_str::(&parsed_settings) { + Ok(s) => { + new_settings = Some(s.clone()); + settings = s.merge(&settings); + tracing::info!(?settings); + } + Err(e) => { + if !settings.ignore_errors { + return Err(e.into()); + } + } + } + } + tracing::debug!(?table.name, ?settings); + if settings.enable { + let comment = + Self::generate_comment_content(table, config, &settings, new_settings)?; + return Ok(file_content.replace(body, &comment)); + } + + // let settings = settings.unwrap(); + // tracing::info!(?settings); + // let merged_settings = settings.merge(&config.comment); + // if merged_settings.enable { + } + } + } + + let comment = Self::generate_comment_content(table, config, &config.comment, None)?; + Ok(format!("{}\n{}", comment, file_content)) + } + pub fn generate_comment_content( + table: Table, + model_config: &ModelConfig, + config: &CommentConfig, + parsed_settings: Option, + ) -> Result { + let mut model_config = model_config.clone(); + model_config.comment = config.clone(); + let column_info_table = if config.column_info { + let mut column_info_table = CTable::new(); + let mut header = Vec::new(); + if config.column_name { + header.push("Name"); + } + if config.column_db_type { + header.push("DbType"); + } + if config.column_rust_type { + header.push("RsType"); + } + if config.column_attributes { + header.push("Attrs"); + } + column_info_table + .load_preset(" -+=++ + ++") + .set_content_arrangement(ContentArrangement::Dynamic) + .set_header(header); + if let Some(width) = config.max_width { + column_info_table.set_width(width); + } + for column in &table.columns { + 
column_info_table.add_row(column.get_info_row(&model_config)?); + } + column_info_table.to_string() + } else { + String::new() + }; + let config_part = match parsed_settings { + Some(settings) => { + let settings_str = serde_yaml::to_string(&settings)?; + let settings_str = settings_str + .lines() + .map(|line| format!(" {}", line)) + .collect::>() + .join("\n"); + format!( + "{SETTINGSDELIMITER}\n{}\n{SETTINGSDELIMITER}\n\n", + settings_str + ) + } + None => String::new(), + }; + + let table_name = &table.name; + let table_name_str = if config.table_name { + format!("Table: {}\n", table_name) + } else { + String::new() + }; + let string = format!("{HEADER}\n{config_part}{table_name_str}\n{column_info_table}"); + + let padded_string = Self::pad_comment(&string); + Ok(padded_string) + } + + pub fn pad_comment(s: &str) -> String { + let parts = s.split('\n').collect::>(); + let mut padded = String::new(); + for (index, part) in parts.iter().enumerate() { + let first = index == 0; + let comment = match first { + true => COMMENTHEAD.to_string(), + false => COMMENTBODY.to_string(), + }; + let padded_part = format!("{} {}\n", comment, part); + padded.push_str(&padded_part); + } + padded.push_str(COMMENTTAIL); + padded + } + // pub async fn generate_header(&self, config: &Config, db_type: &DbType) -> Result { + // + // } +} diff --git a/src/generator/discover.rs b/src/generator/models/discover.rs similarity index 80% rename from src/generator/discover.rs rename to src/generator/models/discover.rs index ea494b9..c49ff44 100644 --- a/src/generator/discover.rs +++ b/src/generator/models/discover.rs @@ -4,17 +4,24 @@ use color_eyre::eyre::{eyre, Context, ContextCompat, Report, Result}; use sea_schema::sea_query::TableCreateStatement; use url::Url; -use crate::config::Config; +use crate::config::db::DbConfig; +#[derive(Debug, Clone)] +pub enum DbType { + MySql, + Postgres, + Sqlite, +} + pub async fn get_tables( database_url: String, - config: &Config, -) -> Result<(Option, 
Vec<TableCreateStatement>)> { + filter: Box<dyn Fn(&str) -> bool>, + database_config: &DbConfig, +) -> Result<(Vec<TableCreateStatement>, DbType)> { let url = Url::parse(&database_url)?; tracing::trace!(?url); let is_sqlite = url.scheme() == "sqlite"; - let filter_tables = config.sea_orm.entity.tables.get_filter(); let database_name: &str = (if !is_sqlite { let database_name = url @@ -31,15 +38,15 @@ pub async fn get_tables( Ok(Default::default()) })?; - let (schema_name, table_stmts) = match url.scheme() { + let (table_stmts, db_type) = match url.scheme() { "mysql" => { use sea_schema::mysql::discovery::SchemaDiscovery; use sqlx::MySql; tracing::info!("Connecting to MySQL"); let connection = sqlx_connect::<MySql>( - config.db.max_connections, - config.db.acquire_timeout, + database_config.max_connections, + database_config.acquire_timeout, url.as_str(), None, ) @@ -51,12 +58,12 @@ pub async fn get_tables( let table_stmts = schema .tables .into_iter() - .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter(&schema.info.name)) // .filter(|schema| filter_hidden_tables(&schema.info.name)) // .filter(|schema| filter_skip_tables(&schema.info.name)) .map(|schema| schema.write()) .collect(); - (None, table_stmts) + (table_stmts, DbType::MySql) } "sqlite" => { use sea_schema::sqlite::discovery::SchemaDiscovery; @@ -64,8 +71,8 @@ pub async fn get_tables( tracing::info!("Connecting to SQLite"); let connection = sqlx_connect::<Sqlite>( - config.db.max_connections, - config.db.acquire_timeout, + database_config.max_connections, + database_config.acquire_timeout, url.as_str(), None, ) @@ -80,22 +87,25 @@ pub async fn get_tables( let table_stmts = schema .tables .into_iter() - .filter(|schema| filter_tables(&schema.name)) + .filter(|schema| filter(&schema.name)) // .filter(|schema| filter_hidden_tables(&schema.name)) // .filter(|schema| filter_skip_tables(&schema.name)) .map(|schema| schema.write()) .collect(); - (None, table_stmts) + (table_stmts, DbType::Sqlite) } "postgres" | "potgresql" => { use 
sea_schema::postgres::discovery::SchemaDiscovery; use sqlx::Postgres; tracing::info!("Connecting to Postgres"); - let schema = &config.db.database_schema.as_deref().unwrap_or("public"); + let schema = &database_config + .database_schema + .as_deref() + .unwrap_or("public"); let connection = sqlx_connect::( - config.db.max_connections, - config.db.acquire_timeout, + database_config.max_connections, + database_config.acquire_timeout, url.as_str(), Some(schema), ) @@ -107,18 +117,18 @@ pub async fn get_tables( let table_stmts = schema .tables .into_iter() - .filter(|schema| filter_tables(&schema.info.name)) + .filter(|schema| filter(&schema.info.name)) // .filter(|schema| filter_hidden_tables(&schema.info.name)) // .filter(|schema| filter_skip_tables(&schema.info.name)) .map(|schema| schema.write()) .collect(); - (config.db.database_schema.clone(), table_stmts) + (table_stmts, DbType::Postgres) } _ => unimplemented!("{} is not supported", url.scheme()), }; tracing::info!("Schema discovered"); - Ok((schema_name, table_stmts)) + Ok((table_stmts, db_type)) } async fn sqlx_connect( max_connections: u32, diff --git a/src/generator/models/file.rs b/src/generator/models/file.rs new file mode 100644 index 0000000..bea570d --- /dev/null +++ b/src/generator/models/file.rs @@ -0,0 +1,101 @@ +use std::path::PathBuf; + +use crate::{config::Config, generator::file::GeneratedFileChunk}; + +use super::{ + comment::ModelCommentGenerator, discover::DbType, table::Table, CommentConfig, ModelConfig, +}; +use color_eyre::Result; +use comfy_table::{ContentArrangement, Table as CTable}; +use comment_parser::{CommentParser, Event}; +use handlebars::Handlebars; +use sea_orm_codegen::OutputFile; +use serde::Serialize; +use tokio::fs; +#[derive(Debug, Clone)] +pub struct FileGenerator { + filename: String, + table: Table, +} +// #[derive(Debug, Clone, Serialize)] +// pub struct FileContext { +// entities_path: String, +// model_path: Option, +// model_name: Option, +// active_model_name: 
Option, +// prelude_path: Option, +// } + +impl FileGenerator { + pub async fn generate_file<'a>( + table: Table, + config: &ModelConfig, + handlebars: &'a Handlebars<'a>, + ) -> Result> { + let mut file_chunks = Vec::new(); + file_chunks.push(GeneratedFileChunk { + path: config.models_path.join("mod.rs"), + content: format!("pub mod {};", table.name), + priority: 0, + }); + let filepath = config.models_path.join(format!("{}.rs", table.name)); + tracing::debug!(?filepath, "Generating file"); + if filepath.exists() { + file_chunks + .extend(Self::handle_existing_file(table, &filepath, config, handlebars).await?); + } else { + } + + // let filepath = config.output.path.join(&self.filename); + // let file_context = FileContext { + // entities_path: self.entities_path.clone(), + // model_name: self.table.name.clone(), + // }; + // let generated_header = self.generate_header(config, db_type).await?; + // if filepath.exists() { + // let mut file_content = fs::read_to_string(filepath).await?; + // if !config.output.models.comment.enable { + // return Ok(file_content); + // } + // let rules = comment_parser::get_syntax("rust").unwrap(); + // let parser = CommentParser::new(&file_content, rules); + // for comment in parser { + // if let Event::BlockComment(body, _) = comment { + // if body.contains(HEADER) { + // tracing::debug!("Found header"); + // file_content = file_content.replace(body, &generated_header); + // return Ok(file_content); + // } + // } + // } + // Ok(format!("{}\n{}", generated_header, file_content)) + // } else { + // let content = handlebars.render("model", &file_context)?; + // Ok(format!("{}{}", generated_header, content)) + // } + + // Ok(OutputFile { + // name: self.filename.clone(), + // content: self.generate_file(config, handlebars, db_type).await?, + // }) + Ok(file_chunks) + } + async fn handle_existing_file<'a>( + table: Table, + filepath: &PathBuf, + config: &ModelConfig, + handlebars: &'a Handlebars<'a>, + ) -> Result> { + let mut 
file_chunks = Vec::new(); + let mut file_content = fs::read_to_string(filepath).await?; + if config.comment.enable { + file_content = ModelCommentGenerator::generate_comment(table, &file_content, config)?; + } + file_chunks.push(GeneratedFileChunk { + path: filepath.clone(), + content: file_content, + priority: 0, + }); + Ok(file_chunks) + } +} diff --git a/src/generator/models/mod.rs b/src/generator/models/mod.rs new file mode 100644 index 0000000..619be39 --- /dev/null +++ b/src/generator/models/mod.rs @@ -0,0 +1,153 @@ +use crate::config::{sea_orm_config::DateTimeCrate, Config}; +use color_eyre::Result; +use discover::DbType; +use file::FileGenerator; +use handlebars::Handlebars; +use sea_orm_codegen::{EntityTransformer, EntityWriterContext, OutputFile}; +use sea_schema::sea_query::TableCreateStatement; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use table::Table; + +use super::file::{GeneratedFile, GeneratedFileChunk}; +pub mod column; +pub mod comment; +pub mod discover; +pub mod file; +pub mod table; +#[derive(Debug, Clone)] +pub struct ModelConfig { + pub models_path: PathBuf, + pub prelude: bool, + pub entities_path: PathBuf, + pub relative_entities_path: String, + pub enable: bool, + pub comment: CommentConfig, + pub db_type: DbType, +} +#[derive(Debug, Clone)] +pub struct CommentConfig { + pub max_width: Option, + pub enable: bool, + pub table_name: bool, + pub column_info: bool, + pub column_name: bool, + pub column_rust_type: bool, + pub column_db_type: bool, + pub column_attributes: bool, + pub ignore_errors: bool, + pub date_time_crate: DateTimeCrate, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CommentConfigSerde { + #[serde(skip_serializing_if = "Option::is_none")] + pub max_width: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enable: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub table_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub info: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rust_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub db_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub attributes: Option, +} +impl CommentConfigSerde { + pub fn merge(&self, config: &CommentConfig) -> CommentConfig { + CommentConfig { + max_width: self.max_width.or(config.max_width), + table_name: self.table_name.unwrap_or(config.table_name), + column_name: self.name.unwrap_or(config.column_name), + column_info: self.info.unwrap_or(config.column_info), + column_rust_type: self.rust_type.unwrap_or(config.column_rust_type), + column_db_type: self.db_type.unwrap_or(config.column_db_type), + column_attributes: self.attributes.unwrap_or(config.column_attributes), + ignore_errors: config.ignore_errors, + enable: self.enable.unwrap_or(config.enable), + date_time_crate: config.date_time_crate.clone(), + } + } +} + +impl ModelConfig { + pub fn new(config: Config, db_type: DbType) -> Self { + let models_path = config.output.path.join(&config.output.models.path); + let entities_path = models_path.join(&config.output.models.entities); + ModelConfig { + db_type, + prelude: config.output.models.prelude, + entities_path, + models_path, + relative_entities_path: config.output.models.entities.clone(), + enable: config.output.models.enable, + comment: CommentConfig { + max_width: config.output.models.comment.max_width, + enable: config.output.models.comment.enable, + table_name: config.output.models.comment.table_name, + column_name: config.output.models.comment.column_name, + column_info: config.output.models.comment.column_info, + column_rust_type: config.output.models.comment.column_rust_type, + column_db_type: config.output.models.comment.column_db_type, + column_attributes: config.output.models.comment.column_attributes, + ignore_errors: config.output.models.comment.ignore_errors, + date_time_crate: 
config.sea_orm.entity.date_time_crate, + }, + } + } +} + +pub async fn generate_models<'a>( + database_url: &str, + config: &Config, + handlebars: &'a Handlebars<'a>, +) -> Result> { + let mut files = Vec::new(); + let db_filter = config.sea_orm.entity.tables.get_filter(); + let (table_stmts, db_type) = + discover::get_tables(database_url.to_owned(), db_filter, &config.db).await?; + let model_config = ModelConfig::new(config.clone(), db_type); + + let writer_context = config.clone().into(); + files.extend( + generate_entities(table_stmts.clone(), model_config.clone(), writer_context).await?, + ); + + files.push(GeneratedFileChunk { + path: model_config.models_path.join("mod.rs"), + content: format!("pub mod {};", model_config.relative_entities_path), + priority: 0, + }); + let tables = table_stmts + .into_iter() + .map(Table::new) + .collect::>>()?; + + if model_config.enable { + for table in tables { + files.extend(FileGenerator::generate_file(table, &model_config, handlebars).await?); + } + } + Ok(files) +} + +pub async fn generate_entities( + table_statements: Vec, + config: ModelConfig, + writer_context: EntityWriterContext, +) -> Result> { + let output = EntityTransformer::transform(table_statements)?.generate(&writer_context); + Ok(output + .files + .into_iter() + .map(|OutputFile { name, content }| GeneratedFileChunk { + path: config.entities_path.join(name), + content, + priority: 0, + }) + .collect::>()) +} diff --git a/src/generator/table.rs b/src/generator/models/table.rs similarity index 87% rename from src/generator/table.rs rename to src/generator/models/table.rs index cf18926..684ae3a 100644 --- a/src/generator/table.rs +++ b/src/generator/models/table.rs @@ -11,7 +11,7 @@ pub struct Table { } impl Table { - pub fn new(statement: TableCreateStatement, config: Config) -> Result { + pub fn new(statement: TableCreateStatement) -> Result
{ let table_name = match statement.get_table_name() { Some(table_ref) => match table_ref { sea_query::TableRef::Table(t) @@ -27,6 +27,12 @@ impl Table { tracing::debug!(?table_name); let columns_raw = statement.get_columns(); let indexes = statement.get_indexes(); + for column in columns_raw { + tracing::debug!(?column); + } + for index in indexes { + tracing::debug!(?index); + } let columns = columns_raw .iter() .map(|column| { diff --git a/src/main.rs b/src/main.rs index 00c1e61..ffd7959 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod config; mod generator; +mod templates; use clap::Parser; use color_eyre::{eyre::eyre, Report, Result}; @@ -8,6 +9,7 @@ use figment::{ providers::{Format, Serialized, Yaml}, Figment, }; +use handlebars::Handlebars; use indicatif::{ProgressBar, ProgressStyle}; use sea_orm_codegen::{ DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile, @@ -37,52 +39,69 @@ async fn main() -> Result<()> { tracing::info!(?args); let mut handlebars = Handlebars::new(); + templates::register_templates(&mut handlebars, &config).await?; - let output_dir = &config.output.path; - let output_internal_entities = output_dir.join("_entities"); - let (_, table_stmts) = generator::discover::get_tables(args.database_url, &config).await?; - let writer_context = config.clone().into(); - let output = EntityTransformer::transform(table_stmts.clone())?.generate(&writer_context); - let mut files = output - .files - .into_iter() - .map(|OutputFile { name, content }| (output_internal_entities.join(name), content)) - .collect::>(); - generator::update_files(table_stmts, config.clone(), &handlebars).await?; - // let generate_files = generator::generate_models(table_stmts, config.clone()) - // .await? 
- // .into_iter() - // .map(|OutputFile { name, content }| (output_dir.join(name), content)) - // .collect::>(); - // files.extend(generate_files); - tracing::info!("Generated {} files", files.len()); - fs::create_dir_all(&output_internal_entities).await?; - let progress_bar = ProgressBar::new((files.len() * 2) as u64) - .with_style( - ProgressStyle::default_bar() - .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")?, - ) - .with_message("Writing files"); + let outputs = generator::generate(&args.database_url, &config, &handlebars).await?; - for (file_path, content) in files.iter() { - progress_bar.set_message(format!("Writing file: {:?}", file_path)); - tracing::debug!(?file_path, "Writing file"); - let mut file = fs::File::create(&file_path).await?; - file.write_all(content.as_bytes()).await?; - progress_bar.inc(1); + // tracing::info!(?outputs, "Generated files"); + for output in outputs.iter() { + tracing::info!(?output, "Generated chunk"); + // let mut file = fs::File::create(&output.path).await?; + // file.write_all(output.content.as_bytes()).await?; } - progress_bar.set_message("Running rustfmt"); - for (file_path, ..) 
in files.iter() { - tracing::debug!(?file_path, "Running rustfmt"); - progress_bar.set_message(format!("Running rustfmt: {:?}", file_path)); - let exit_status = Command::new("rustfmt").arg(file_path).status().await?; // Get the status code + let merged_outputs = generator::file::combine_chunks(outputs)?; + for output in merged_outputs.iter() { + tracing::info!(?output, "Merged file"); + let mut file = fs::File::create(&output.path).await?; + file.write_all(output.content.as_bytes()).await?; + } + for output in merged_outputs.iter() { + tracing::info!(?output, "Running rustfmt"); + let exit_status = Command::new("rustfmt").arg(&output.path).status().await?; if !exit_status.success() { - // Propagate the error if any return Err(eyre!("Failed to run rustfmt")); } - progress_bar.inc(1); } - progress_bar.finish(); + + // + // let output_dir = &config.output.path; + // + // let output_models_dir = output_dir.join(&config.output.models.path); + // + // let output_internal_entities = output_models_dir.join(&config.output.models.entities); + // let (_, table_stmts, db_type) = + // generator::discover::get_tables(args.database_url, &config).await?; + // let writer_context = config.clone().into(); + // let output = EntityTransformer::transform(table_stmts.clone())?.generate(&writer_context); + // let mut files = output + // .files + // .into_iter() + // .map(|OutputFile { name, content }| (output_internal_entities.join(name), content)) + // .collect::>(); + // // generator::update_files(table_stmts, config.clone(), &handlebars).await?; + // let generate_files = generator::update_files(table_stmts, config.clone(), &handlebars, db_type) + // .await? 
+ // .into_iter() + // .map(|OutputFile { name, content }| (output_models_dir.join(name), content)) + // .collect::>(); + // files.extend(generate_files); + // tracing::info!("Generated {} files", files.len()); + // fs::create_dir_all(&output_internal_entities).await?; + // + // for (file_path, content) in files.iter() { + // tracing::info!(?file_path, "Writing file"); + // let mut file = fs::File::create(&file_path).await?; + // file.write_all(content.as_bytes()).await?; + // } + // + // for (file_path, ..) in files.iter() { + // tracing::info!(?file_path, "Running rustfmt"); + // let exit_status = Command::new("rustfmt").arg(file_path).status().await?; // Get the status code + // if !exit_status.success() { + // // Propagate the error if any + // return Err(eyre!("Failed to run rustfmt")); + // } + // } Ok(()) } diff --git a/src/templates.rs b/src/templates.rs new file mode 100644 index 0000000..4fe9910 --- /dev/null +++ b/src/templates.rs @@ -0,0 +1,67 @@ +use crate::config::Config; +use color_eyre::eyre::{Context, ContextCompat, Result}; +use handlebars::Handlebars; +use include_dir::{include_dir, Dir, DirEntry, File}; +use serde_yaml::Value; +use std::path::PathBuf; +use tokio::fs; + +static TEMPLATE_DIR: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/templates"); + +async fn handle_direntries<'a>( + entries: &[DirEntry<'a>], + handlebars: &mut Handlebars<'a>, +) -> Result<()> { + for entry in entries.iter().filter(|file| { + file.path() + .extension() + .is_some_and(|f| f.to_str().is_some_and(|f| f == "hbs")) + || file.as_dir().is_some() + }) { + match entry { + DirEntry::File(file) => { + let path = file.path().with_extension(""); + let name = path + .to_str() + .context("Failed to convert path to str")? 
+ .replace("/", "."); + let content = file + .contents_utf8() + .context(format!("Template {} failed to parse", name))?; + + tracing::debug!(?name, "Registering template"); + if !handlebars.has_template(&name) { + handlebars.register_template_string(&name, content)?; + } else { + tracing::debug!(?name, "Template already registered, skipping"); + } + } + DirEntry::Dir(dir) => { + Box::pin(handle_direntries(dir.entries(), handlebars)).await?; + } + } + } + + Ok(()) +} +pub async fn register_templates(handlebars: &mut Handlebars<'_>, config: &Config) -> Result<()> { + if let Some(templates) = &config.templates { + for (name, value) in templates.iter() { + let Value::String(name) = name else { + return Err(color_eyre::eyre::eyre!("Invalid template name")); + }; + let Value::String(path) = value else { + return Err(color_eyre::eyre::eyre!("Invalid template value")); + }; + let mut path = PathBuf::from(path); + if let Some(templates_dir) = &config.templates_dir { + path = templates_dir.join(path); + } + tracing::info!(?name, ?path, "Registering template"); + let content = fs::read_to_string(path).await?; + handlebars.register_template_string(name, content)?; + } + } + handle_direntries(TEMPLATE_DIR.entries(), handlebars).await?; + Ok(()) +} diff --git a/src/templates/model.rs.hbs b/src/templates/model.rs.hbs deleted file mode 100644 index 0b85b03..0000000 --- a/src/templates/model.rs.hbs +++ /dev/null @@ -1,12 +0,0 @@ -use {{entities_path}}::{ActiveMode, Model, Entity}; -{{#if has_prelude}} -use {{prelude}}; -{{/if}} - -impl Model { - -} - -impl ActiveModel { - -} diff --git a/templates/model.hbs b/templates/model.hbs new file mode 100644 index 0000000..2223236 --- /dev/null +++ b/templates/model.hbs @@ -0,0 +1,9 @@ +use {{entities_path}}::{{model_name}}::{ActiveModel, Model, Entity}; +use sea_orm::ActiveModelBehavior; + +#[async_trait::async_trait] +impl ActiveModelBehavior for ActiveModel {} + +impl Model {} + +impl ActiveModel {} diff --git 
a/templates/modelprelude.hbs b/templates/modelprelude.hbs new file mode 100644 index 0000000..3317116 --- /dev/null +++ b/templates/modelprelude.hbs @@ -0,0 +1,9 @@ +use {{prelude_path}}::*; +use sea_orm::ActiveModelBehavior; + +#[async_trait::async_trait] +impl ActiveModelBehavior for {{active_model_name}} {} + +impl {{model_name}} {} + +impl {{active_model_name}} {}