2025-04-02 14:09:05 +04:00
parent dc9695dac0
commit 4b2a9f5be0
4 changed files with 169 additions and 88 deletions

View File

@@ -7,6 +7,7 @@ edition = "2021"
 clap = { version = "4.5.32", features = ["derive", "env"] }
 color-eyre = "0.6.3"
 figment = { version = "0.10.19", features = ["yaml"] }
+heck = "0.5.0"
 inquire = "0.7.5"
 sea-orm-codegen = "1.1.8"
 sea-schema = { version = "0.16.1", features = ["sqlx-all"] }
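
The new heck dependency is a case-conversion helper crate; nothing in the hunks below calls it yet, so the call site is an assumption. A typical use in an entity generator is turning table names into Rust-style identifiers:

    // Sketch only: heck is added to Cargo.toml above but not yet used in this commit.
    use heck::ToSnakeCase;

    fn module_name(table: &str) -> String {
        table.to_snake_case() // e.g. "UserAccounts" -> "user_accounts"
    }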

View File

@@ -1,7 +1,12 @@
-use sea_orm_codegen::DateTimeCrate as CodegenDateTimeCrate;
+use std::path::PathBuf;
+
+use color_eyre::Report;
+use sea_orm_codegen::{
+    DateTimeCrate as CodegenDateTimeCrate, EntityWriterContext, WithPrelude, WithSerde,
+};
 use serde::{Deserialize, Deserializer, Serialize};
 use serde_yaml::Value;
-use tracing::info;
+use tracing::instrument;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(rename_all = "lowercase")]
@@ -98,6 +103,12 @@ impl Serialize for Prelude {
 pub struct Config {
     pub db: DbConfig,
     pub sea_orm: SeaOrmConfig,
+    pub output: OutputConfig,
+}
+
+#[derive(Deserialize, Serialize, Debug, Clone)]
+pub struct OutputConfig {
+    pub path: PathBuf,
 }
 
 #[derive(Deserialize, Serialize, Debug, Clone)]
@@ -128,11 +139,13 @@ pub struct SeaOrmEntityConfig {
     pub extra_derives: SeaOrmExtraDerivesConfig,
     pub extra_attributes: SeaOrmExtraAttributesConfig,
     pub date_time_crate: DateTimeCrate,
+    pub with_copy_enums: bool,
 }
 
 #[derive(Deserialize, Serialize, Debug, Clone)]
 pub struct SeaOrmTableConfig {
     pub include_hidden: bool,
+    pub skip_seaql_migrations: bool,
     #[serde(flatten)]
     pub table_config: Option<TableConfig>,
 }
@@ -149,7 +162,8 @@ pub struct SeaOrmExtraAttributesConfig {
     pub eenum: Vec<String>,
 }
 #[derive(Deserialize, Serialize, Debug, Clone)]
-enum DateTimeCrate {
+#[serde(rename_all = "lowercase")]
+pub enum DateTimeCrate {
     Time,
     Chrono,
 }
@@ -181,8 +195,13 @@ impl SeaOrmTableConfig {
                     })
                 }
             }
+        } else if self.skip_seaql_migrations {
+            Box::new(move |table: &String| {
+                (include_hidden || !table.starts_with('_'))
+                    && !table.starts_with("seaql_migrations")
+            })
         } else {
-            Box::new(move |table: &String| include_hidden || !table.starts_with('_'))
+            Box::new(move |table: &String| (include_hidden || !table.starts_with('_')))
         }
     }
 }
@@ -206,6 +225,7 @@ impl Default for Config {
                     format: EntityFormat::Compact,
                     tables: SeaOrmTableConfig {
                         include_hidden: false,
+                        skip_seaql_migrations: true,
                         table_config: None,
                     },
                     extra_derives: SeaOrmExtraDerivesConfig {
@@ -217,8 +237,60 @@ impl Default for Config {
                         eenum: Vec::new(),
                     },
                     date_time_crate: DateTimeCrate::Chrono,
+                    with_copy_enums: false,
                 },
             },
+            output: OutputConfig {
+                path: PathBuf::from("./entities"),
+            },
         }
     }
 }
+
+impl EntityFormat {
+    pub fn is_expanded(&self) -> bool {
+        matches!(self, EntityFormat::Expanded)
+    }
+}
+
+impl From<Prelude> for WithPrelude {
+    fn from(val: Prelude) -> Self {
+        match val {
+            Prelude::Enabled => WithPrelude::All,
+            Prelude::Disabled => WithPrelude::None,
+            Prelude::AllowUnusedImports => WithPrelude::AllAllowUnusedImports,
+        }
+    }
+}
+
+impl From<SerdeEnable> for WithSerde {
+    fn from(val: SerdeEnable) -> Self {
+        match val {
+            SerdeEnable::Both => WithSerde::Both,
+            SerdeEnable::Serialize => WithSerde::Serialize,
+            SerdeEnable::Deserialize => WithSerde::Deserialize,
+            SerdeEnable::None => WithSerde::None,
+        }
+    }
+}
+
+impl From<Config> for EntityWriterContext {
+    fn from(val: Config) -> Self {
+        EntityWriterContext::new(
+            val.sea_orm.entity.format.is_expanded(),
+            val.sea_orm.prelude.into(),
+            val.sea_orm.serde.enable.into(),
+            val.sea_orm.entity.with_copy_enums,
+            val.sea_orm.entity.date_time_crate.into(),
+            val.db.database_schema,
+            false,
+            val.sea_orm.serde.skip_deserializing_primary_key,
+            val.sea_orm.serde.skip_hidden_column,
+            val.sea_orm.entity.extra_derives.model,
+            val.sea_orm.entity.extra_attributes.model,
+            val.sea_orm.entity.extra_derives.eenum,
+            val.sea_orm.entity.extra_attributes.eenum,
+            false,
+            false,
+        )
+    }
+}
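
Taken together with the Default impl above, the new skip_seaql_migrations flag changes what SeaOrmTableConfig::get_filter lets through. A minimal sketch of the behaviour under the default config (table names are illustrative, not from this commit):

    let config = Config::default();
    let filter = config.sea_orm.entity.tables.get_filter();

    assert!(filter(&"users".to_string()));             // ordinary table: kept
    assert!(!filter(&"seaql_migrations".to_string())); // skipped by the new flag
    assert!(!filter(&"_hidden".to_string()));          // skipped while include_hidden is false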

View File

@@ -1,13 +1,15 @@
 use core::time;
 
+use crate::Config;
 use color_eyre::{
     eyre::{eyre, ContextCompat, Report},
     Result,
 };
-use sea_schema::sea_query::TableCreateStatement;
+use sea_orm_codegen::OutputFile;
+use sea_schema::sea_query::{self, TableCreateStatement};
+use tokio::{fs, task::JoinSet};
 use url::Url;
-use crate::Config;
 
 pub async fn get_tables(
     database_url: String,
     config: &Config,
@@ -18,6 +20,7 @@ pub async fn get_tables(
     let is_sqlite = url.scheme() == "sqlite";
 
     let filter_tables = config.sea_orm.entity.tables.get_filter();
     // let filter_tables = |table: &String| -> bool {
     //     config.sea_orm.entity.table.only.is_empty() || config.sea_orm.table.only.contains(table)
     // };
@@ -135,7 +138,6 @@ pub async fn get_tables(
     Ok((schema_name, table_stmts))
 }
 
-
 async fn sqlx_connect<DB>(
     max_connections: u32,
     acquire_timeout: u64,
@@ -164,3 +166,53 @@ where
     }
     pool_options.connect(url).await.map_err(Into::into)
 }
+
+pub async fn generate_models(
+    tables: Vec<TableCreateStatement>,
+    config: Config,
+) -> Result<Vec<OutputFile>> {
+    tracing::info!(?tables);
+    let output_path = config.output.path;
+    let files = tables
+        .into_iter()
+        .map(|table| {
+            let output_path = output_path.clone();
+            async move {
+                let table_name = match table.get_table_name() {
+                    Some(table_ref) => match table_ref {
+                        sea_query::TableRef::Table(t)
+                        | sea_query::TableRef::SchemaTable(_, t)
+                        | sea_query::TableRef::DatabaseSchemaTable(_, _, t)
+                        | sea_query::TableRef::TableAlias(t, _)
+                        | sea_query::TableRef::SchemaTableAlias(_, t, _)
+                        | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => {
+                            t.to_string()
+                        }
+                        _ => unimplemented!(),
+                    },
+                    None => return Err(eyre!("Table name not found")),
+                };
+                let file_path = output_path.join(&table_name);
+                let exists = file_path.exists();
+                let content = match exists {
+                    true => {
+                        // let file_content = fs::read_to_string(path)
+                    }
+                    false => {}
+                };
+                Ok(OutputFile {
+                    name: format!("{}.rs", table_name),
+                    content: String::new(),
+                })
+            }
+        })
+        .collect::<JoinSet<Result<OutputFile>>>();
+    let files = files
+        .join_all()
+        .await
+        .into_iter()
+        .collect::<Result<Vec<OutputFile>>>()?;
+
+    Ok(files)
+}
+//
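
In generate_models the content handling is still a stub: the match on exists has empty arms and content: String::new() is written regardless. One way the commented-out hint could be filled in, as a sketch only (assuming the intent is to preserve an existing file's contents):

    // Hypothetical completion of the `exists` branch; `fs` is tokio::fs as imported above.
    let content = if file_path.exists() {
        fs::read_to_string(&file_path).await?
    } else {
        String::new()
    };

    Ok(OutputFile {
        name: format!("{}.rs", table_name),
        content,
    })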

View File

@@ -1,93 +1,21 @@
 mod config;
 mod generate;
 
-use std::{fs, path::PathBuf, str::FromStr};
+use std::{path::PathBuf, str::FromStr};
-
 use clap::Parser;
-use color_eyre::{Report, Result};
+use color_eyre::{eyre::eyre, Report, Result};
 use config::Config;
 use figment::{
     providers::{Format, Serialized, Yaml},
     Figment,
 };
 use sea_orm_codegen::{
-    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, WithPrelude,
-    WithSerde,
+    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
+    WithPrelude, WithSerde,
 };
 use serde::{Deserialize, Serialize};
+use tokio::{fs, io::AsyncWriteExt, process::Command};
 
-// #[derive(Deserialize, Serialize, Debug, Clone)]
-// struct Config {
-//     sea_orm: SeaOrmConfig,
-// }
-// #[derive(Deserialize, Serialize, Debug, Clone)]
-// struct SeaOrmConfig {
-//     expanded_format: bool,
-//     table: SeaOrmTableConfig,
-//     with_serde: String,
-//     with_prelude: String,
-//     with_copy_enums: bool,
-//     serde_skip_deserializing_primary_key: bool,
-//     serde_skip_hidden_column: bool,
-//     max_connections: u32,
-//     acquire_timeout: u64,
-//     database_schema: Option<String>,
-//     date_format: DateTimeCrate,
-// }
-// #[derive(Deserialize, Serialize, Debug, Clone)]
-// struct SeaOrmTableConfig {
-//     include_hidden: bool,
-//     only: Vec<String>,
-//     exclude: Vec<String>,
-// }
-//
-// impl Default for Config {
-//     fn default() -> Self {
-//         Self {
-//             sea_orm: SeaOrmConfig {
-//                 table: SeaOrmTableConfig {
-//                     include_hidden: false,
-//                     only: Vec::new(),
-//                     exclude: Vec::new(),
-//                 },
-//                 database_schema: None,
-//                 max_connections: 10,
-//                 acquire_timeout: 30,
-//                 expanded_format: false,
-//                 with_copy_enums: false,
-//                 with_serde: String::from("none"),
-//                 with_prelude: String::from("none"),
-//                 serde_skip_hidden_column: false,
-//                 serde_skip_deserializing_primary_key: false,
-//                 date_format: DateTimeCrate::Time,
-//             },
-//         }
-//     }
-// }
-// impl TryInto<EntityWriterContext> for Config {
-//     type Error = Report;
-//     fn try_into(self) -> Result<EntityWriterContext> {
-//         Ok(EntityWriterContext::new(
-//             self.sea_orm.expanded_format,
-//             WithPrelude::from_str(&self.sea_orm.with_prelude)?,
-//             WithSerde::from_str(&self.sea_orm.with_serde)?,
-//             self.sea_orm.with_copy_enums,
-//             self.sea_orm.date_format.into(),
-//             self.sea_orm.database_schema,
-//             false,
-//             self.sea_orm.serde_skip_deserializing_primary_key,
-//             self.sea_orm.serde_skip_hidden_column,
-//             Vec::new(),
-//             Vec::new(),
-//             Vec::new(),
-//             Vec::new(),
-//             false,
-//             false,
-//         ))
-//     }
-// }
-//
-
 #[derive(Parser, Debug)]
 struct Args {
     #[clap(short, long, default_value = "generator.yml")]
@@ -102,13 +30,41 @@ async fn main() -> Result<()> {
     let args = Args::parse();
     let config: Config = Figment::new()
-        // .merge(Serialized::defaults(Config::default()))
+        .merge(Serialized::defaults(Config::default()))
         .merge(Yaml::file(&args.config))
         .extract()?;
     tracing::info!(?config);
     tracing::info!(?args);
 
-    // let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?;
-    // let writer_context = config.clone().try_into()?;
-    // let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
+    let output_dir = &config.output.path;
+    let output_internal_entities = output_dir.join("_entities");
+    let (_, table_stmts) = generate::get_tables(args.database_url, &config).await?;
+    let writer_context = config.clone().into();
+    let output = EntityTransformer::transform(table_stmts.clone())?.generate(&writer_context);
+    let mut files = output
+        .files
+        .into_iter()
+        .map(|OutputFile { name, content }| (output_internal_entities.join(name), content))
+        .collect::<Vec<_>>();
+    let generate_files = generate::generate_models(table_stmts, config.clone())
+        .await?
+        .into_iter()
+        .map(|OutputFile { name, content }| (output_dir.join(name), content))
+        .collect::<Vec<_>>();
+    files.extend(generate_files);
+    tracing::info!("Generated {} files", files.len());
+    fs::create_dir_all(&output_internal_entities).await?;
+    for (file_path, content) in files.iter() {
+        tracing::info!(?file_path, "Writing file");
+        let mut file = fs::File::create(&file_path).await?;
+        file.write_all(content.as_bytes()).await?;
+    }
+    for (file_path, ..) in files.iter() {
+        tracing::info!(?file_path, "Running rustfmt");
+        let exit_status = Command::new("rustfmt").arg(file_path).status().await?; // Get the status code
+        if !exit_status.success() {
+            // Propagate the error if any
+            return Err(eyre!("Failed to run rustfmt"));
+        }
+    }
     Ok(())
 }
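
After this change the binary merges generator.yml over Config::default() via Figment, writes the transformer output into an _entities subdirectory of output.path, appends the files produced by generate::generate_models, and runs rustfmt over everything it wrote. An invocation would look roughly like cargo run -- --config generator.yml plus a database URL; the exact flag or environment variable spelling for database_url is not visible in these hunks, so that part is an assumption.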