This commit is contained in:
2025-04-01 21:33:03 +04:00
parent f012a96173
commit dc9695dac0
4 changed files with 335 additions and 126 deletions

View File

@@ -6,11 +6,12 @@ edition = "2021"
[dependencies] [dependencies]
clap = { version = "4.5.32", features = ["derive", "env"] } clap = { version = "4.5.32", features = ["derive", "env"] }
color-eyre = "0.6.3" color-eyre = "0.6.3"
confique = { version = "0.3.0", features = ["yaml", "toml"] } figment = { version = "0.10.19", features = ["yaml"] }
inquire = "0.7.5" inquire = "0.7.5"
sea-orm-codegen = "1.1.8" sea-orm-codegen = "1.1.8"
sea-schema = { version = "0.16.1", features = ["sqlx-all"] } sea-schema = { version = "0.16.1", features = ["sqlx-all"] }
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }
serde_yaml = "0.9.34"
sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] } sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] }
syn = { version = "2.0.100", features = ["extra-traits", "full"] } syn = { version = "2.0.100", features = ["extra-traits", "full"] }
tokio = { version = "1.44.1", features = ["full"] } tokio = { version = "1.44.1", features = ["full"] }

224
src/config.rs Normal file
View File

@@ -0,0 +1,224 @@
use sea_orm_codegen::DateTimeCrate as CodegenDateTimeCrate;
use serde::{Deserialize, Deserializer, Serialize};
use serde_yaml::Value;
use tracing::info;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EntityFormat {
Expanded,
Compact,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum TableConfig {
Specific { specific: Vec<String> },
Exclude { exclude: Vec<String> },
}
#[derive(Debug, Clone)]
pub enum SerdeEnable {
Both,
Serialize,
Deserialize,
None,
}
#[derive(Debug, Clone)]
pub enum Prelude {
Enabled,
Disabled,
AllowUnusedImports,
}
impl<'de> Deserialize<'de> for SerdeEnable {
fn deserialize<D>(deserializer: D) -> Result<SerdeEnable, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) if s == "serialize" => Ok(SerdeEnable::Serialize),
Value::String(s) if s == "deserialize" => Ok(SerdeEnable::Deserialize),
Value::Bool(true) => Ok(SerdeEnable::Both),
Value::Bool(false) => Ok(SerdeEnable::None),
_ => Err(serde::de::Error::custom(
"expected 'serialize', 'deserialize', 'true' or 'false'",
)),
}
}
}
impl Serialize for SerdeEnable {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
SerdeEnable::Both => serializer.serialize_bool(true),
SerdeEnable::Serialize => serializer.serialize_str("serialize"),
SerdeEnable::Deserialize => serializer.serialize_str("deserialize"),
SerdeEnable::None => serializer.serialize_bool(false),
}
}
}
impl<'de> Deserialize<'de> for Prelude {
fn deserialize<D>(deserializer: D) -> Result<Prelude, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::Bool(true) => Ok(Prelude::Enabled),
Value::Bool(false) => Ok(Prelude::Disabled),
Value::String(s) if s == "allow_unused_imports" => Ok(Prelude::AllowUnusedImports),
_ => Err(serde::de::Error::custom(
"expected 'true', 'false', or 'allow_unused_imports'",
)),
}
}
}
impl Serialize for Prelude {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Prelude::Enabled => serializer.serialize_bool(true),
Prelude::Disabled => serializer.serialize_bool(false),
Prelude::AllowUnusedImports => serializer.serialize_str("allow_unused_imports"),
}
}
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Config {
pub db: DbConfig,
pub sea_orm: SeaOrmConfig,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DbConfig {
pub database_schema: Option<String>,
pub max_connections: u32,
pub acquire_timeout: u64,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmConfig {
pub prelude: Prelude,
pub serde: SeaOrmSerdeConfig,
pub entity: SeaOrmEntityConfig,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmSerdeConfig {
pub enable: SerdeEnable,
pub skip_deserializing_primary_key: bool,
pub skip_hidden_column: bool,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmEntityConfig {
pub format: EntityFormat,
pub tables: SeaOrmTableConfig,
pub extra_derives: SeaOrmExtraDerivesConfig,
pub extra_attributes: SeaOrmExtraAttributesConfig,
pub date_time_crate: DateTimeCrate,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmTableConfig {
pub include_hidden: bool,
#[serde(flatten)]
pub table_config: Option<TableConfig>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmExtraDerivesConfig {
pub model: Vec<String>,
#[serde(rename = "enum")]
pub eenum: Vec<String>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmExtraAttributesConfig {
pub model: Vec<String>,
#[serde(rename = "enum")]
pub eenum: Vec<String>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
enum DateTimeCrate {
Time,
Chrono,
}
impl From<DateTimeCrate> for CodegenDateTimeCrate {
fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
match date_time_crate {
DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
DateTimeCrate::Time => CodegenDateTimeCrate::Time,
}
}
}
impl SeaOrmTableConfig {
pub fn get_filter(&self) -> Box<dyn Fn(&String) -> bool> {
let include_hidden = self.include_hidden;
if let Some(table) = &self.table_config {
match table {
TableConfig::Specific { specific } => {
let specific = specific.clone();
Box::new(move |table: &String| {
(include_hidden || !table.starts_with('_')) && specific.contains(table)
})
}
TableConfig::Exclude { exclude } => {
let exclude = exclude.clone();
Box::new(move |table: &String| {
(include_hidden || !table.starts_with('_')) && !exclude.contains(table)
})
}
}
} else {
Box::new(move |table: &String| include_hidden || !table.starts_with('_'))
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
db: DbConfig {
database_schema: None,
max_connections: 10,
acquire_timeout: 5,
},
sea_orm: SeaOrmConfig {
prelude: Prelude::Enabled,
serde: SeaOrmSerdeConfig {
enable: SerdeEnable::None,
skip_deserializing_primary_key: false,
skip_hidden_column: false,
},
entity: SeaOrmEntityConfig {
format: EntityFormat::Compact,
tables: SeaOrmTableConfig {
include_hidden: false,
table_config: None,
},
extra_derives: SeaOrmExtraDerivesConfig {
model: Vec::new(),
eenum: Vec::new(),
},
extra_attributes: SeaOrmExtraAttributesConfig {
model: Vec::new(),
eenum: Vec::new(),
},
date_time_crate: DateTimeCrate::Chrono,
},
},
}
}
}

View File

@@ -17,20 +17,21 @@ pub async fn get_tables(
tracing::trace!(?url); tracing::trace!(?url);
let is_sqlite = url.scheme() == "sqlite"; let is_sqlite = url.scheme() == "sqlite";
let filter_tables = |table: &String| -> bool { let filter_tables = config.sea_orm.entity.tables.get_filter();
config.sea_orm.table.only.is_empty() || config.sea_orm.table.only.contains(table) // let filter_tables = |table: &String| -> bool {
}; // config.sea_orm.entity.table.only.is_empty() || config.sea_orm.table.only.contains(table)
// };
let filter_hidden_tables = |table: &str| -> bool { //
if false { // let filter_hidden_tables = |table: &str| -> bool {
true // if false {
} else { // true
!table.starts_with('_') // } else {
} // !table.starts_with('_')
}; // }
// };
let filter_skip_tables = //
|table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) }; // let filter_skip_tables =
// |table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) };
let database_name: &str = (if !is_sqlite { let database_name: &str = (if !is_sqlite {
let database_name = url let database_name = url
.path_segments() .path_segments()
@@ -53,8 +54,8 @@ pub async fn get_tables(
tracing::info!("Connecting to MySQL"); tracing::info!("Connecting to MySQL");
let connection = sqlx_connect::<MySql>( let connection = sqlx_connect::<MySql>(
config.sea_orm.max_connections, config.db.max_connections,
config.sea_orm.acquire_timeout, config.db.acquire_timeout,
url.as_str(), url.as_str(),
None, None,
) )
@@ -67,8 +68,8 @@ pub async fn get_tables(
.tables .tables
.into_iter() .into_iter()
.filter(|schema| filter_tables(&schema.info.name)) .filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name)) // .filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name)) // .filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write()) .map(|schema| schema.write())
.collect(); .collect();
(None, table_stmts) (None, table_stmts)
@@ -79,8 +80,8 @@ pub async fn get_tables(
tracing::info!("Connecting to SQLite"); tracing::info!("Connecting to SQLite");
let connection = sqlx_connect::<Sqlite>( let connection = sqlx_connect::<Sqlite>(
config.sea_orm.max_connections, config.db.max_connections,
config.sea_orm.acquire_timeout, config.db.acquire_timeout,
url.as_str(), url.as_str(),
None, None,
) )
@@ -96,8 +97,8 @@ pub async fn get_tables(
.tables .tables
.into_iter() .into_iter()
.filter(|schema| filter_tables(&schema.name)) .filter(|schema| filter_tables(&schema.name))
.filter(|schema| filter_hidden_tables(&schema.name)) // .filter(|schema| filter_hidden_tables(&schema.name))
.filter(|schema| filter_skip_tables(&schema.name)) // .filter(|schema| filter_skip_tables(&schema.name))
.map(|schema| schema.write()) .map(|schema| schema.write())
.collect(); .collect();
(None, table_stmts) (None, table_stmts)
@@ -107,14 +108,10 @@ pub async fn get_tables(
use sqlx::Postgres; use sqlx::Postgres;
tracing::info!("Connecting to Postgres"); tracing::info!("Connecting to Postgres");
let schema = config let schema = &config.db.database_schema.as_deref().unwrap_or("public");
.sea_orm
.database_schema
.as_deref()
.unwrap_or("public");
let connection = sqlx_connect::<Postgres>( let connection = sqlx_connect::<Postgres>(
config.sea_orm.max_connections, config.db.max_connections,
config.sea_orm.acquire_timeout, config.db.acquire_timeout,
url.as_str(), url.as_str(),
Some(schema), Some(schema),
) )
@@ -126,11 +123,11 @@ pub async fn get_tables(
.tables .tables
.into_iter() .into_iter()
.filter(|schema| filter_tables(&schema.info.name)) .filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name)) // .filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name)) // .filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write()) .map(|schema| schema.write())
.collect(); .collect();
(config.sea_orm.database_schema.clone(), table_stmts) (config.db.database_schema.clone(), table_stmts)
} }
_ => unimplemented!("{} is not supported", url.scheme()), _ => unimplemented!("{} is not supported", url.scheme()),
}; };

View File

@@ -1,10 +1,12 @@
mod config;
mod generate; mod generate;
use std::{fs, path::PathBuf, str::FromStr}; use std::{fs, path::PathBuf, str::FromStr};
use clap::Parser; use clap::Parser;
use color_eyre::{Report, Result}; use color_eyre::{Report, Result};
use config::Config;
use figment::{ use figment::{
providers::{Format, Serialized, Toml}, providers::{Format, Serialized, Yaml},
Figment, Figment,
}; };
use sea_orm_codegen::{ use sea_orm_codegen::{
@@ -13,97 +15,82 @@ use sea_orm_codegen::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Deserialize, Serialize, Debug, Clone)] // #[derive(Deserialize, Serialize, Debug, Clone)]
struct Config { // struct Config {
sea_orm: SeaOrmConfig, // sea_orm: SeaOrmConfig,
} // }
#[derive(Deserialize, Serialize, Debug, Clone)] // #[derive(Deserialize, Serialize, Debug, Clone)]
enum DateTimeCrate { // struct SeaOrmConfig {
Time, // expanded_format: bool,
Chrono, // table: SeaOrmTableConfig,
} // with_serde: String,
// with_prelude: String,
impl From<DateTimeCrate> for CodegenDateTimeCrate { // with_copy_enums: bool,
fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate { // serde_skip_deserializing_primary_key: bool,
match date_time_crate { // serde_skip_hidden_column: bool,
DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono, // max_connections: u32,
DateTimeCrate::Time => CodegenDateTimeCrate::Time, // acquire_timeout: u64,
} // database_schema: Option<String>,
} // date_format: DateTimeCrate,
} // }
// #[derive(Deserialize, Serialize, Debug, Clone)]
#[derive(Deserialize, Serialize, Debug, Clone)] // struct SeaOrmTableConfig {
struct SeaOrmConfig { // include_hidden: bool,
expanded_format: bool, // only: Vec<String>,
table: SeaOrmTableConfig, // exclude: Vec<String>,
with_serde: String, // }
with_prelude: String, //
with_copy_enums: bool, // impl Default for Config {
serde_skip_deserializing_primary_key: bool, // fn default() -> Self {
serde_skip_hidden_column: bool, // Self {
max_connections: u32, // sea_orm: SeaOrmConfig {
acquire_timeout: u64, // table: SeaOrmTableConfig {
database_schema: Option<String>, // include_hidden: false,
date_format: DateTimeCrate, // only: Vec::new(),
} // exclude: Vec::new(),
#[derive(Deserialize, Serialize, Debug, Clone)] // },
struct SeaOrmTableConfig { // database_schema: None,
include_hidden: bool, // max_connections: 10,
only: Vec<String>, // acquire_timeout: 30,
exclude: Vec<String>, // expanded_format: false,
} // with_copy_enums: false,
// with_serde: String::from("none"),
impl Default for Config { // with_prelude: String::from("none"),
fn default() -> Self { // serde_skip_hidden_column: false,
Self { // serde_skip_deserializing_primary_key: false,
sea_orm: SeaOrmConfig { // date_format: DateTimeCrate::Time,
table: SeaOrmTableConfig { // },
include_hidden: false, // }
only: Vec::new(), // }
exclude: Vec::new(), // }
},
database_schema: None,
max_connections: 10,
acquire_timeout: 30,
expanded_format: false,
with_copy_enums: false,
with_serde: String::from("none"),
with_prelude: String::from("none"),
serde_skip_hidden_column: false,
serde_skip_deserializing_primary_key: false,
date_format: DateTimeCrate::Time,
},
}
}
}
impl TryInto<EntityWriterContext> for Config {
type Error = Report;
fn try_into(self) -> Result<EntityWriterContext> {
Ok(EntityWriterContext::new(
self.sea_orm.expanded_format,
WithPrelude::from_str(&self.sea_orm.with_prelude)?,
WithSerde::from_str(&self.sea_orm.with_serde)?,
self.sea_orm.with_copy_enums,
self.sea_orm.date_format.into(),
self.sea_orm.database_schema,
false,
self.sea_orm.serde_skip_deserializing_primary_key,
self.sea_orm.serde_skip_hidden_column,
Vec::new(),
Vec::new(),
Vec::new(),
Vec::new(),
false,
false,
))
}
}
// impl TryInto<EntityWriterContext> for Config {
// type Error = Report;
// fn try_into(self) -> Result<EntityWriterContext> {
// Ok(EntityWriterContext::new(
// self.sea_orm.expanded_format,
// WithPrelude::from_str(&self.sea_orm.with_prelude)?,
// WithSerde::from_str(&self.sea_orm.with_serde)?,
// self.sea_orm.with_copy_enums,
// self.sea_orm.date_format.into(),
// self.sea_orm.database_schema,
// false,
// self.sea_orm.serde_skip_deserializing_primary_key,
// self.sea_orm.serde_skip_hidden_column,
// Vec::new(),
// Vec::new(),
// Vec::new(),
// Vec::new(),
// false,
// false,
// ))
// }
// }
//
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
struct Args { struct Args {
#[clap(short, long, default_value = "generator.toml")] #[clap(short, long, default_value = "generator.yml")]
config: String, config: String,
#[clap(short, long, env = "DATABASE_URL")] #[clap(short, long, env = "DATABASE_URL")]
database_url: String, database_url: String,
@@ -115,13 +102,13 @@ async fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
let config: Config = Figment::new() let config: Config = Figment::new()
.merge(Serialized::defaults(Config::default())) // .merge(Serialized::defaults(Config::default()))
.merge(Toml::file(&args.config)) .merge(Yaml::file(&args.config))
.extract()?; .extract()?;
tracing::info!(?config); tracing::info!(?config);
tracing::info!(?args); tracing::info!(?args);
let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?; // let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?;
let writer_context = config.clone().try_into()?; // let writer_context = config.clone().try_into()?;
let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context); // let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
Ok(()) Ok(())
} }