This commit is contained in:
2025-04-01 21:33:03 +04:00
parent f012a96173
commit dc9695dac0
4 changed files with 335 additions and 126 deletions

View File

@@ -6,11 +6,12 @@ edition = "2021"
[dependencies]
clap = { version = "4.5.32", features = ["derive", "env"] }
color-eyre = "0.6.3"
confique = { version = "0.3.0", features = ["yaml", "toml"] }
figment = { version = "0.10.19", features = ["yaml"] }
inquire = "0.7.5"
sea-orm-codegen = "1.1.8"
sea-schema = { version = "0.16.1", features = ["sqlx-all"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_yaml = "0.9.34"
sqlx = { version = "0.8.3", features = ["mysql", "postgres", "sqlite"] }
syn = { version = "2.0.100", features = ["extra-traits", "full"] }
tokio = { version = "1.44.1", features = ["full"] }

224
src/config.rs Normal file
View File

@@ -0,0 +1,224 @@
use sea_orm_codegen::DateTimeCrate as CodegenDateTimeCrate;
use serde::{Deserialize, Deserializer, Serialize};
use serde_yaml::Value;
use tracing::info;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EntityFormat {
Expanded,
Compact,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
pub enum TableConfig {
Specific { specific: Vec<String> },
Exclude { exclude: Vec<String> },
}
#[derive(Debug, Clone)]
pub enum SerdeEnable {
Both,
Serialize,
Deserialize,
None,
}
#[derive(Debug, Clone)]
pub enum Prelude {
Enabled,
Disabled,
AllowUnusedImports,
}
impl<'de> Deserialize<'de> for SerdeEnable {
fn deserialize<D>(deserializer: D) -> Result<SerdeEnable, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) if s == "serialize" => Ok(SerdeEnable::Serialize),
Value::String(s) if s == "deserialize" => Ok(SerdeEnable::Deserialize),
Value::Bool(true) => Ok(SerdeEnable::Both),
Value::Bool(false) => Ok(SerdeEnable::None),
_ => Err(serde::de::Error::custom(
"expected 'serialize', 'deserialize', 'true' or 'false'",
)),
}
}
}
impl Serialize for SerdeEnable {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
SerdeEnable::Both => serializer.serialize_bool(true),
SerdeEnable::Serialize => serializer.serialize_str("serialize"),
SerdeEnable::Deserialize => serializer.serialize_str("deserialize"),
SerdeEnable::None => serializer.serialize_bool(false),
}
}
}
impl<'de> Deserialize<'de> for Prelude {
fn deserialize<D>(deserializer: D) -> Result<Prelude, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::Bool(true) => Ok(Prelude::Enabled),
Value::Bool(false) => Ok(Prelude::Disabled),
Value::String(s) if s == "allow_unused_imports" => Ok(Prelude::AllowUnusedImports),
_ => Err(serde::de::Error::custom(
"expected 'true', 'false', or 'allow_unused_imports'",
)),
}
}
}
impl Serialize for Prelude {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Prelude::Enabled => serializer.serialize_bool(true),
Prelude::Disabled => serializer.serialize_bool(false),
Prelude::AllowUnusedImports => serializer.serialize_str("allow_unused_imports"),
}
}
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Config {
pub db: DbConfig,
pub sea_orm: SeaOrmConfig,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DbConfig {
pub database_schema: Option<String>,
pub max_connections: u32,
pub acquire_timeout: u64,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmConfig {
pub prelude: Prelude,
pub serde: SeaOrmSerdeConfig,
pub entity: SeaOrmEntityConfig,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmSerdeConfig {
pub enable: SerdeEnable,
pub skip_deserializing_primary_key: bool,
pub skip_hidden_column: bool,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmEntityConfig {
pub format: EntityFormat,
pub tables: SeaOrmTableConfig,
pub extra_derives: SeaOrmExtraDerivesConfig,
pub extra_attributes: SeaOrmExtraAttributesConfig,
pub date_time_crate: DateTimeCrate,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmTableConfig {
pub include_hidden: bool,
#[serde(flatten)]
pub table_config: Option<TableConfig>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmExtraDerivesConfig {
pub model: Vec<String>,
#[serde(rename = "enum")]
pub eenum: Vec<String>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct SeaOrmExtraAttributesConfig {
pub model: Vec<String>,
#[serde(rename = "enum")]
pub eenum: Vec<String>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
enum DateTimeCrate {
Time,
Chrono,
}
impl From<DateTimeCrate> for CodegenDateTimeCrate {
fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
match date_time_crate {
DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
DateTimeCrate::Time => CodegenDateTimeCrate::Time,
}
}
}
impl SeaOrmTableConfig {
pub fn get_filter(&self) -> Box<dyn Fn(&String) -> bool> {
let include_hidden = self.include_hidden;
if let Some(table) = &self.table_config {
match table {
TableConfig::Specific { specific } => {
let specific = specific.clone();
Box::new(move |table: &String| {
(include_hidden || !table.starts_with('_')) && specific.contains(table)
})
}
TableConfig::Exclude { exclude } => {
let exclude = exclude.clone();
Box::new(move |table: &String| {
(include_hidden || !table.starts_with('_')) && !exclude.contains(table)
})
}
}
} else {
Box::new(move |table: &String| include_hidden || !table.starts_with('_'))
}
}
}
impl Default for Config {
fn default() -> Self {
Self {
db: DbConfig {
database_schema: None,
max_connections: 10,
acquire_timeout: 5,
},
sea_orm: SeaOrmConfig {
prelude: Prelude::Enabled,
serde: SeaOrmSerdeConfig {
enable: SerdeEnable::None,
skip_deserializing_primary_key: false,
skip_hidden_column: false,
},
entity: SeaOrmEntityConfig {
format: EntityFormat::Compact,
tables: SeaOrmTableConfig {
include_hidden: false,
table_config: None,
},
extra_derives: SeaOrmExtraDerivesConfig {
model: Vec::new(),
eenum: Vec::new(),
},
extra_attributes: SeaOrmExtraAttributesConfig {
model: Vec::new(),
eenum: Vec::new(),
},
date_time_crate: DateTimeCrate::Chrono,
},
},
}
}
}

View File

@@ -17,20 +17,21 @@ pub async fn get_tables(
tracing::trace!(?url);
let is_sqlite = url.scheme() == "sqlite";
let filter_tables = |table: &String| -> bool {
config.sea_orm.table.only.is_empty() || config.sea_orm.table.only.contains(table)
};
let filter_hidden_tables = |table: &str| -> bool {
if false {
true
} else {
!table.starts_with('_')
}
};
let filter_skip_tables =
|table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) };
let filter_tables = config.sea_orm.entity.tables.get_filter();
// let filter_tables = |table: &String| -> bool {
// config.sea_orm.entity.table.only.is_empty() || config.sea_orm.table.only.contains(table)
// };
//
// let filter_hidden_tables = |table: &str| -> bool {
// if false {
// true
// } else {
// !table.starts_with('_')
// }
// };
//
// let filter_skip_tables =
// |table: &String| -> bool { !config.sea_orm.table.exclude.contains(table) };
let database_name: &str = (if !is_sqlite {
let database_name = url
.path_segments()
@@ -53,8 +54,8 @@ pub async fn get_tables(
tracing::info!("Connecting to MySQL");
let connection = sqlx_connect::<MySql>(
config.sea_orm.max_connections,
config.sea_orm.acquire_timeout,
config.db.max_connections,
config.db.acquire_timeout,
url.as_str(),
None,
)
@@ -67,8 +68,8 @@ pub async fn get_tables(
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name))
// .filter(|schema| filter_hidden_tables(&schema.info.name))
// .filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write())
.collect();
(None, table_stmts)
@@ -79,8 +80,8 @@ pub async fn get_tables(
tracing::info!("Connecting to SQLite");
let connection = sqlx_connect::<Sqlite>(
config.sea_orm.max_connections,
config.sea_orm.acquire_timeout,
config.db.max_connections,
config.db.acquire_timeout,
url.as_str(),
None,
)
@@ -96,8 +97,8 @@ pub async fn get_tables(
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.name))
.filter(|schema| filter_hidden_tables(&schema.name))
.filter(|schema| filter_skip_tables(&schema.name))
// .filter(|schema| filter_hidden_tables(&schema.name))
// .filter(|schema| filter_skip_tables(&schema.name))
.map(|schema| schema.write())
.collect();
(None, table_stmts)
@@ -107,14 +108,10 @@ pub async fn get_tables(
use sqlx::Postgres;
tracing::info!("Connecting to Postgres");
let schema = config
.sea_orm
.database_schema
.as_deref()
.unwrap_or("public");
let schema = &config.db.database_schema.as_deref().unwrap_or("public");
let connection = sqlx_connect::<Postgres>(
config.sea_orm.max_connections,
config.sea_orm.acquire_timeout,
config.db.max_connections,
config.db.acquire_timeout,
url.as_str(),
Some(schema),
)
@@ -126,11 +123,11 @@ pub async fn get_tables(
.tables
.into_iter()
.filter(|schema| filter_tables(&schema.info.name))
.filter(|schema| filter_hidden_tables(&schema.info.name))
.filter(|schema| filter_skip_tables(&schema.info.name))
// .filter(|schema| filter_hidden_tables(&schema.info.name))
// .filter(|schema| filter_skip_tables(&schema.info.name))
.map(|schema| schema.write())
.collect();
(config.sea_orm.database_schema.clone(), table_stmts)
(config.db.database_schema.clone(), table_stmts)
}
_ => unimplemented!("{} is not supported", url.scheme()),
};

View File

@@ -1,10 +1,12 @@
mod config;
mod generate;
use std::{fs, path::PathBuf, str::FromStr};
use clap::Parser;
use color_eyre::{Report, Result};
use config::Config;
use figment::{
providers::{Format, Serialized, Toml},
providers::{Format, Serialized, Yaml},
Figment,
};
use sea_orm_codegen::{
@@ -13,97 +15,82 @@ use sea_orm_codegen::{
};
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Serialize, Debug, Clone)]
struct Config {
sea_orm: SeaOrmConfig,
}
// #[derive(Deserialize, Serialize, Debug, Clone)]
// struct Config {
// sea_orm: SeaOrmConfig,
// }
#[derive(Deserialize, Serialize, Debug, Clone)]
enum DateTimeCrate {
Time,
Chrono,
}
impl From<DateTimeCrate> for CodegenDateTimeCrate {
fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
match date_time_crate {
DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
DateTimeCrate::Time => CodegenDateTimeCrate::Time,
}
}
}
#[derive(Deserialize, Serialize, Debug, Clone)]
struct SeaOrmConfig {
expanded_format: bool,
table: SeaOrmTableConfig,
with_serde: String,
with_prelude: String,
with_copy_enums: bool,
serde_skip_deserializing_primary_key: bool,
serde_skip_hidden_column: bool,
max_connections: u32,
acquire_timeout: u64,
database_schema: Option<String>,
date_format: DateTimeCrate,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
struct SeaOrmTableConfig {
include_hidden: bool,
only: Vec<String>,
exclude: Vec<String>,
}
impl Default for Config {
fn default() -> Self {
Self {
sea_orm: SeaOrmConfig {
table: SeaOrmTableConfig {
include_hidden: false,
only: Vec::new(),
exclude: Vec::new(),
},
database_schema: None,
max_connections: 10,
acquire_timeout: 30,
expanded_format: false,
with_copy_enums: false,
with_serde: String::from("none"),
with_prelude: String::from("none"),
serde_skip_hidden_column: false,
serde_skip_deserializing_primary_key: false,
date_format: DateTimeCrate::Time,
},
}
}
}
impl TryInto<EntityWriterContext> for Config {
type Error = Report;
fn try_into(self) -> Result<EntityWriterContext> {
Ok(EntityWriterContext::new(
self.sea_orm.expanded_format,
WithPrelude::from_str(&self.sea_orm.with_prelude)?,
WithSerde::from_str(&self.sea_orm.with_serde)?,
self.sea_orm.with_copy_enums,
self.sea_orm.date_format.into(),
self.sea_orm.database_schema,
false,
self.sea_orm.serde_skip_deserializing_primary_key,
self.sea_orm.serde_skip_hidden_column,
Vec::new(),
Vec::new(),
Vec::new(),
Vec::new(),
false,
false,
))
}
}
// #[derive(Deserialize, Serialize, Debug, Clone)]
// struct SeaOrmConfig {
// expanded_format: bool,
// table: SeaOrmTableConfig,
// with_serde: String,
// with_prelude: String,
// with_copy_enums: bool,
// serde_skip_deserializing_primary_key: bool,
// serde_skip_hidden_column: bool,
// max_connections: u32,
// acquire_timeout: u64,
// database_schema: Option<String>,
// date_format: DateTimeCrate,
// }
// #[derive(Deserialize, Serialize, Debug, Clone)]
// struct SeaOrmTableConfig {
// include_hidden: bool,
// only: Vec<String>,
// exclude: Vec<String>,
// }
//
// impl Default for Config {
// fn default() -> Self {
// Self {
// sea_orm: SeaOrmConfig {
// table: SeaOrmTableConfig {
// include_hidden: false,
// only: Vec::new(),
// exclude: Vec::new(),
// },
// database_schema: None,
// max_connections: 10,
// acquire_timeout: 30,
// expanded_format: false,
// with_copy_enums: false,
// with_serde: String::from("none"),
// with_prelude: String::from("none"),
// serde_skip_hidden_column: false,
// serde_skip_deserializing_primary_key: false,
// date_format: DateTimeCrate::Time,
// },
// }
// }
// }
// impl TryInto<EntityWriterContext> for Config {
// type Error = Report;
// fn try_into(self) -> Result<EntityWriterContext> {
// Ok(EntityWriterContext::new(
// self.sea_orm.expanded_format,
// WithPrelude::from_str(&self.sea_orm.with_prelude)?,
// WithSerde::from_str(&self.sea_orm.with_serde)?,
// self.sea_orm.with_copy_enums,
// self.sea_orm.date_format.into(),
// self.sea_orm.database_schema,
// false,
// self.sea_orm.serde_skip_deserializing_primary_key,
// self.sea_orm.serde_skip_hidden_column,
// Vec::new(),
// Vec::new(),
// Vec::new(),
// Vec::new(),
// false,
// false,
// ))
// }
// }
//
#[derive(Parser, Debug)]
struct Args {
#[clap(short, long, default_value = "generator.toml")]
#[clap(short, long, default_value = "generator.yml")]
config: String,
#[clap(short, long, env = "DATABASE_URL")]
database_url: String,
@@ -115,13 +102,13 @@ async fn main() -> Result<()> {
let args = Args::parse();
let config: Config = Figment::new()
.merge(Serialized::defaults(Config::default()))
.merge(Toml::file(&args.config))
// .merge(Serialized::defaults(Config::default()))
.merge(Yaml::file(&args.config))
.extract()?;
tracing::info!(?config);
tracing::info!(?args);
let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?;
let writer_context = config.clone().try_into()?;
let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
// let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?;
// let writer_context = config.clone().try_into()?;
// let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
Ok(())
}