rewrite generator entirely, again
This commit is contained in:
236
src/generator/models/column.rs
Normal file
236
src/generator/models/column.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
use color_eyre::{eyre::ContextCompat, Result};
|
||||
use comfy_table::Cell;
|
||||
use heck::ToUpperCamelCase;
|
||||
use sea_schema::sea_query::{ColumnDef, ColumnSpec, ColumnType, IndexCreateStatement};
|
||||
|
||||
use crate::config::{sea_orm_config::DateTimeCrate, Config};
|
||||
|
||||
use super::{discover::DbType, ModelConfig};
|
||||
/// A single table column together with its type and constraint specs,
/// as discovered from the live database schema.
#[derive(Clone, Debug)]
pub struct Column {
    // Column name as reported by the database.
    pub name: String,
    // Database-agnostic column type from `sea_query`.
    pub col_type: ColumnType,
    // Constraint/attribute specs (primary key, not null, unique, ...),
    // including constraints recovered from indexes (see `Column::new`).
    pub attrs: Vec<ColumnSpec>,
}
|
||||
|
||||
impl Column {
|
||||
pub fn new(column: ColumnDef, index: Option<IndexCreateStatement>) -> Result<Self> {
|
||||
let name = column.get_column_name();
|
||||
let col_type = column
|
||||
.get_column_type()
|
||||
.context("Unable to get column type")?
|
||||
.clone();
|
||||
let mut attrs = column.get_column_spec().clone();
|
||||
if let Some(index) = index {
|
||||
if index.is_unique_key() {
|
||||
attrs.push(ColumnSpec::UniqueKey)
|
||||
}
|
||||
if index.is_primary_key() {
|
||||
attrs.push(ColumnSpec::PrimaryKey);
|
||||
}
|
||||
}
|
||||
Ok(Column {
|
||||
name: name.to_string(),
|
||||
col_type,
|
||||
attrs: attrs.to_vec(),
|
||||
})
|
||||
}
|
||||
pub fn get_info_row(&self, config: &ModelConfig) -> Result<Vec<Cell>> {
|
||||
let column_type_rust = self.get_rs_type(&config.comment.date_time_crate);
|
||||
let column_type = self.get_db_type(&config.db_type);
|
||||
let attrs = self.attrs_to_string();
|
||||
let mut cols = Vec::new();
|
||||
if config.comment.column_name {
|
||||
cols.push(Cell::new(self.name.clone()))
|
||||
}
|
||||
if config.comment.column_name {
|
||||
cols.push(Cell::new(column_type.clone()))
|
||||
}
|
||||
if config.comment.column_rust_type {
|
||||
cols.push(Cell::new(column_type_rust.clone()))
|
||||
}
|
||||
if config.comment.column_attributes {
|
||||
cols.push(Cell::new(attrs.clone()));
|
||||
}
|
||||
Ok(cols)
|
||||
}
|
||||
pub fn attrs_to_string(&self) -> String {
|
||||
self.attrs
|
||||
.iter()
|
||||
.filter_map(Self::get_addr_type)
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(", ")
|
||||
}
|
||||
pub fn get_addr_type(attr: &ColumnSpec) -> Option<String> {
|
||||
match attr {
|
||||
ColumnSpec::PrimaryKey => Some("primary key".to_owned()),
|
||||
ColumnSpec::Null => todo!(),
|
||||
ColumnSpec::NotNull => Some("not null".to_owned()),
|
||||
ColumnSpec::Default(simple_expr) => todo!(),
|
||||
ColumnSpec::AutoIncrement => Some("autoincrement".to_owned()),
|
||||
ColumnSpec::UniqueKey => Some("unique key".to_owned()),
|
||||
ColumnSpec::Check(simple_expr) => todo!(),
|
||||
ColumnSpec::Generated { expr, stored } => todo!(),
|
||||
ColumnSpec::Extra(_) => todo!(),
|
||||
ColumnSpec::Comment(_) => todo!(),
|
||||
}
|
||||
}
|
||||
/// Renders this column's type as the dialect-specific SQL type name for the
/// given database flavour (the schema comment's "DbType" column).
pub fn get_db_type(&self, db_type: &DbType) -> String {
    // Recursive helper so array element types (Postgres `T[]`) can reuse
    // the same mapping.
    fn write_db_type(col_type: &ColumnType, db_type: &DbType) -> String {
        #[allow(unreachable_patterns)]
        match (col_type, db_type) {
            // Text-ish types are spelled the same in every dialect.
            (ColumnType::Char(_), _) => "char".to_owned(),
            (ColumnType::String(_), _) => "varchar".to_owned(),
            (ColumnType::Text, _) => "text".to_owned(),
            // Integer widths and (un)signedness differ per dialect.
            (ColumnType::TinyInteger, DbType::MySql | DbType::Sqlite) => "tinyint".to_owned(),
            (ColumnType::TinyInteger, DbType::Postgres) => "smallint".to_owned(),
            (ColumnType::SmallInteger, _) => "smallint".to_owned(),
            (ColumnType::Integer, DbType::MySql) => "int".to_owned(),
            (ColumnType::Integer, _) => "integer".to_owned(),
            (ColumnType::BigInteger, DbType::MySql | DbType::Postgres) => "bigint".to_owned(),
            (ColumnType::BigInteger, DbType::Sqlite) => "integer".to_owned(),
            (ColumnType::TinyUnsigned, DbType::MySql) => "tinyint unsigned".to_owned(),
            (ColumnType::TinyUnsigned, DbType::Postgres) => "smallint".to_owned(),
            (ColumnType::TinyUnsigned, DbType::Sqlite) => "tinyint".to_owned(),
            (ColumnType::SmallUnsigned, DbType::MySql) => "smallint unsigned".to_owned(),
            (ColumnType::SmallUnsigned, DbType::Postgres | DbType::Sqlite) => {
                "smallint".to_owned()
            }
            (ColumnType::Unsigned, DbType::MySql) => "int unsigned".to_owned(),
            (ColumnType::Unsigned, DbType::Postgres | DbType::Sqlite) => "integer".to_owned(),
            (ColumnType::BigUnsigned, DbType::MySql) => "bigint unsigned".to_owned(),
            (ColumnType::BigUnsigned, DbType::Postgres) => "bigint".to_owned(),
            (ColumnType::BigUnsigned, DbType::Sqlite) => "integer".to_owned(),
            (ColumnType::Float, DbType::MySql | DbType::Sqlite) => "float".to_owned(),
            (ColumnType::Float, DbType::Postgres) => "real".to_owned(),
            (ColumnType::Double, DbType::MySql | DbType::Sqlite) => "double".to_owned(),
            (ColumnType::Double, DbType::Postgres) => "double precision".to_owned(),
            (ColumnType::Decimal(_), DbType::MySql | DbType::Postgres) => "decimal".to_owned(),
            (ColumnType::Decimal(_), DbType::Sqlite) => "real".to_owned(),
            // Date/time: the `_text` suffixes mirror how these columns are
            // represented for SQLite in the rest of this codebase.
            (ColumnType::DateTime, DbType::MySql) => "datetime".to_owned(),
            (ColumnType::DateTime, DbType::Postgres) => "timestamp w/o tz".to_owned(),
            (ColumnType::DateTime, DbType::Sqlite) => "datetime_text".to_owned(),
            (ColumnType::Timestamp, DbType::MySql | DbType::Postgres) => "timestamp".to_owned(),
            (ColumnType::Timestamp, DbType::Sqlite) => "timestamp_text".to_owned(),
            (ColumnType::TimestampWithTimeZone, DbType::MySql) => "timestamp".to_owned(),
            (ColumnType::TimestampWithTimeZone, DbType::Postgres) => {
                "timestamp w tz".to_owned()
            }
            (ColumnType::TimestampWithTimeZone, DbType::Sqlite) => {
                "timestamp_with_timezone_text".to_owned()
            }
            (ColumnType::Time, DbType::MySql | DbType::Postgres) => "time".to_owned(),
            (ColumnType::Time, DbType::Sqlite) => "time_text".to_owned(),
            (ColumnType::Date, DbType::MySql | DbType::Postgres) => "date".to_owned(),
            (ColumnType::Date, DbType::Sqlite) => "date_text".to_owned(),
            (ColumnType::Year, DbType::MySql) => "year".to_owned(),
            (ColumnType::Interval(_, _), DbType::Postgres) => "interval".to_owned(),
            // Binary types.
            (ColumnType::Blob, DbType::MySql | DbType::Sqlite) => "blob".to_owned(),
            (ColumnType::Blob, DbType::Postgres) => "bytea".to_owned(),
            (ColumnType::Binary(_), DbType::MySql) => "binary".to_owned(),
            (ColumnType::Binary(_), DbType::Postgres) => "bytea".to_owned(),
            (ColumnType::Binary(_), DbType::Sqlite) => "blob".to_owned(),
            (ColumnType::VarBinary(_), DbType::MySql) => "varbinary".to_owned(),
            (ColumnType::VarBinary(_), DbType::Postgres) => "bytea".to_owned(),
            (ColumnType::VarBinary(_), DbType::Sqlite) => "varbinary_blob".to_owned(),
            (ColumnType::Bit(_), DbType::MySql | DbType::Postgres) => "bit".to_owned(),
            (ColumnType::VarBit(_), DbType::MySql) => "bit".to_owned(),
            (ColumnType::VarBit(_), DbType::Postgres) => "varbit".to_owned(),
            (ColumnType::Boolean, DbType::MySql | DbType::Postgres) => "bool".to_owned(),
            (ColumnType::Boolean, DbType::Sqlite) => "boolean".to_owned(),
            (ColumnType::Money(_), DbType::MySql) => "decimal".to_owned(),
            (ColumnType::Money(_), DbType::Postgres) => "money".to_owned(),
            (ColumnType::Money(_), DbType::Sqlite) => "real_money".to_owned(),
            (ColumnType::Json, DbType::MySql | DbType::Postgres) => "json".to_owned(),
            (ColumnType::Json, DbType::Sqlite) => "json_text".to_owned(),
            (ColumnType::JsonBinary, DbType::MySql) => "json".to_owned(),
            (ColumnType::JsonBinary, DbType::Postgres) => "jsonb".to_owned(),
            (ColumnType::JsonBinary, DbType::Sqlite) => "jsonb_text".to_owned(),
            (ColumnType::Uuid, DbType::MySql) => "binary(16)".to_owned(),
            (ColumnType::Uuid, DbType::Postgres) => "uuid".to_owned(),
            (ColumnType::Uuid, DbType::Sqlite) => "uuid_text".to_owned(),
            (ColumnType::Enum { name, .. }, DbType::MySql) => {
                format!("ENUM({})", name.to_string().to_upper_camel_case())
            }
            (ColumnType::Enum { name, .. }, DbType::Postgres) => {
                name.to_string().to_uppercase()
            }
            (ColumnType::Enum { .. }, DbType::Sqlite) => "enum_text".to_owned(),
            // Postgres-only types.
            (ColumnType::Array(column_type), DbType::Postgres) => {
                format!("{}[]", write_db_type(column_type, db_type)).to_uppercase()
            }
            (ColumnType::Vector(_), DbType::Postgres) => "vector".to_owned(),
            (ColumnType::Cidr, DbType::Postgres) => "cidr".to_owned(),
            (ColumnType::Inet, DbType::Postgres) => "inet".to_owned(),
            (ColumnType::MacAddr, DbType::Postgres) => "macaddr".to_owned(),
            (ColumnType::LTree, DbType::Postgres) => "ltree".to_owned(),

            // Any (type, dialect) pair not listed above panics here.
            _ => unimplemented!(),
        }
    }
    write_db_type(&self.col_type, db_type)
}
|
||||
/// Renders this column's type as the Rust type name for the comment table's
/// "RsType" column. `date_time_crate` selects between chrono- and
/// time-flavoured aliases for the date/time types.
pub fn get_rs_type(&self, date_time_crate: &DateTimeCrate) -> String {
    // Recursive helper so array element types can reuse the same mapping.
    fn write_rs_type(col_type: &ColumnType, date_time_crate: &DateTimeCrate) -> String {
        #[allow(unreachable_patterns)]
        match col_type {
            ColumnType::Char(_)
            | ColumnType::String(_)
            | ColumnType::Text
            | ColumnType::Custom(_) => "String".to_owned(),
            ColumnType::TinyInteger => "i8".to_owned(),
            ColumnType::SmallInteger => "i16".to_owned(),
            ColumnType::Integer => "i32".to_owned(),
            ColumnType::BigInteger => "i64".to_owned(),
            ColumnType::TinyUnsigned => "u8".to_owned(),
            ColumnType::SmallUnsigned => "u16".to_owned(),
            ColumnType::Unsigned => "u32".to_owned(),
            ColumnType::BigUnsigned => "u64".to_owned(),
            ColumnType::Float => "f32".to_owned(),
            ColumnType::Double => "f64".to_owned(),
            ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(),
            // Date/time alias names depend on the configured crate.
            ColumnType::Date => match date_time_crate {
                DateTimeCrate::Chrono => "Date".to_owned(),
                DateTimeCrate::Time => "TimeDate".to_owned(),
            },
            ColumnType::Time => match date_time_crate {
                DateTimeCrate::Chrono => "Time".to_owned(),
                DateTimeCrate::Time => "TimeTime".to_owned(),
            },
            ColumnType::DateTime => match date_time_crate {
                DateTimeCrate::Chrono => "DateTime".to_owned(),
                DateTimeCrate::Time => "TimeDateTime".to_owned(),
            },
            ColumnType::Timestamp => match date_time_crate {
                DateTimeCrate::Chrono => "DateTimeUtc".to_owned(),
                DateTimeCrate::Time => "TimeDateTime".to_owned(),
            },
            ColumnType::TimestampWithTimeZone => match date_time_crate {
                DateTimeCrate::Chrono => "DateTimeWithTimeZone".to_owned(),
                DateTimeCrate::Time => "TimeDateTimeWithTimeZone".to_owned(),
            },
            ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(),
            ColumnType::Uuid => "Uuid".to_owned(),
            ColumnType::Binary(_) | ColumnType::VarBinary(_) | ColumnType::Blob => {
                "Vec<u8>".to_owned()
            }
            ColumnType::Boolean => "bool".to_owned(),
            ColumnType::Enum { name, .. } => name.to_string().to_upper_camel_case(),
            ColumnType::Array(column_type) => {
                format!("Vec<{}>", write_rs_type(column_type, date_time_crate))
            }
            ColumnType::Vector(_) => "::pgvector::Vector".to_owned(),
            // A single-bit column maps to bool; wider bit fields to bytes.
            ColumnType::Bit(None | Some(1)) => "bool".to_owned(),
            ColumnType::Bit(_) | ColumnType::VarBit(_) => "Vec<u8>".to_owned(),
            ColumnType::Year => "i32".to_owned(),
            ColumnType::Cidr | ColumnType::Inet => "IpNetwork".to_owned(),
            ColumnType::Interval(_, _) | ColumnType::MacAddr | ColumnType::LTree => {
                "String".to_owned()
            }
            _ => unimplemented!(),
        }
    }
    write_rs_type(&self.col_type, date_time_crate)
}
|
||||
}
|
||||
154
src/generator/models/comment.rs
Normal file
154
src/generator/models/comment.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use crate::generator::models::{CommentConfig, CommentConfigSerde};
|
||||
|
||||
use super::{table::Table, ModelConfig};
|
||||
use color_eyre::{eyre, owo_colors::colors::White, Result};
|
||||
use comfy_table::{ContentArrangement, Table as CTable};
|
||||
use comment_parser::{CommentParser, Event};
|
||||
|
||||
// Marker text identifying a generated schema-information comment.
const HEADER: &str = r#"== Schema Information"#;
// Pieces of the rendered block comment: opening line, per-line prefix,
// and closing line (see `pad_comment`).
const COMMENTHEAD: &str = r#"/*"#;
const COMMENTBODY: &str = r#" *"#;
const COMMENTTAIL: &str = r#"*/"#;
// Fence around the optional embedded per-file settings (YAML) block.
const SETTINGSDELIMITER: &str = r#"```"#;
||||
|
||||
/// Stateless generator for the "Schema Information" block comment injected
/// at the top of model files; all functionality lives in associated fns.
pub struct ModelCommentGenerator {}
|
||||
|
||||
impl ModelCommentGenerator {
|
||||
pub fn find_settings_block(file_content: &str) -> Option<String> {
|
||||
let delimiter_length = SETTINGSDELIMITER.len();
|
||||
let start_pos = file_content.find(SETTINGSDELIMITER)?;
|
||||
let end_pos = file_content[start_pos + delimiter_length..].find(SETTINGSDELIMITER)?;
|
||||
let content = &file_content[start_pos + delimiter_length..start_pos + end_pos];
|
||||
let content = content.replace(&format!("\n{COMMENTBODY}"), "\n");
|
||||
Some(content)
|
||||
}
|
||||
pub fn generate_comment(
|
||||
table: Table,
|
||||
file_content: &str,
|
||||
config: &ModelConfig,
|
||||
) -> Result<String> {
|
||||
let rules = comment_parser::get_syntax("rust").unwrap();
|
||||
let parser = CommentParser::new(&file_content, rules);
|
||||
for comment in parser {
|
||||
if let Event::BlockComment(body, _) = comment {
|
||||
if body.contains(HEADER) {
|
||||
tracing::debug!("Found header");
|
||||
let mut settings = config.comment.clone();
|
||||
let mut new_settings = None;
|
||||
if let Some(parsed_settings) = Self::find_settings_block(file_content) {
|
||||
tracing::info!(?new_settings);
|
||||
match serde_yaml::from_str::<CommentConfigSerde>(&parsed_settings) {
|
||||
Ok(s) => {
|
||||
new_settings = Some(s.clone());
|
||||
settings = s.merge(&settings);
|
||||
tracing::info!(?settings);
|
||||
}
|
||||
Err(e) => {
|
||||
if !settings.ignore_errors {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!(?table.name, ?settings);
|
||||
if settings.enable {
|
||||
let comment =
|
||||
Self::generate_comment_content(table, config, &settings, new_settings)?;
|
||||
return Ok(file_content.replace(body, &comment));
|
||||
}
|
||||
|
||||
// let settings = settings.unwrap();
|
||||
// tracing::info!(?settings);
|
||||
// let merged_settings = settings.merge(&config.comment);
|
||||
// if merged_settings.enable {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let comment = Self::generate_comment_content(table, config, &config.comment, None)?;
|
||||
Ok(format!("{}\n{}", comment, file_content))
|
||||
}
|
||||
/// Renders the full schema-information comment body for `table`: the
/// re-serialized settings fence (if any), the optional "Table:" line, and
/// the column info table, all wrapped in a block-comment frame.
pub fn generate_comment_content(
    table: Table,
    model_config: &ModelConfig,
    config: &CommentConfig,
    parsed_settings: Option<CommentConfigSerde>,
) -> Result<String> {
    // Row rendering reads flags via `model_config.comment`, so install the
    // (possibly per-file merged) comment config there first.
    let mut model_config = model_config.clone();
    model_config.comment = config.clone();
    let column_info_table = if config.column_info {
        let mut column_info_table = CTable::new();
        let mut header = Vec::new();
        // Header cells must follow the same flag order used in
        // `Column::get_info_row` so columns line up.
        if config.column_name {
            header.push("Name");
        }
        if config.column_db_type {
            header.push("DbType");
        }
        if config.column_rust_type {
            header.push("RsType");
        }
        if config.column_attributes {
            header.push("Attrs");
        }
        column_info_table
            .load_preset(" -+=++ + ++")
            .set_content_arrangement(ContentArrangement::Dynamic)
            .set_header(header);
        if let Some(width) = config.max_width {
            column_info_table.set_width(width);
        }
        for column in &table.columns {
            column_info_table.add_row(column.get_info_row(&model_config)?);
        }
        column_info_table.to_string()
    } else {
        String::new()
    };
    // Re-emit the per-file settings fence (if one was parsed) so that it
    // survives comment regeneration.
    let config_part = match parsed_settings {
        Some(settings) => {
            let settings_str = serde_yaml::to_string(&settings)?;
            // Indent each YAML line one space so it sits nicely after the
            // ` *` comment prefix added by `pad_comment`.
            let settings_str = settings_str
                .lines()
                .map(|line| format!(" {}", line))
                .collect::<Vec<_>>()
                .join("\n");
            format!(
                "{SETTINGSDELIMITER}\n{}\n{SETTINGSDELIMITER}\n\n",
                settings_str
            )
        }
        None => String::new(),
    };

    let table_name = &table.name;
    let table_name_str = if config.table_name {
        format!("Table: {}\n", table_name)
    } else {
        String::new()
    };
    let string = format!("{HEADER}\n{config_part}{table_name_str}\n{column_info_table}");

    // Wrap everything in the `/* ... */` block-comment frame.
    let padded_string = Self::pad_comment(&string);
    Ok(padded_string)
}
|
||||
|
||||
pub fn pad_comment(s: &str) -> String {
|
||||
let parts = s.split('\n').collect::<Vec<_>>();
|
||||
let mut padded = String::new();
|
||||
for (index, part) in parts.iter().enumerate() {
|
||||
let first = index == 0;
|
||||
let comment = match first {
|
||||
true => COMMENTHEAD.to_string(),
|
||||
false => COMMENTBODY.to_string(),
|
||||
};
|
||||
let padded_part = format!("{} {}\n", comment, part);
|
||||
padded.push_str(&padded_part);
|
||||
}
|
||||
padded.push_str(COMMENTTAIL);
|
||||
padded
|
||||
}
|
||||
// pub async fn generate_header(&self, config: &Config, db_type: &DbType) -> Result<String> {
|
||||
//
|
||||
// }
|
||||
}
|
||||
160
src/generator/models/discover.rs
Normal file
160
src/generator/models/discover.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use core::time;
|
||||
|
||||
use color_eyre::eyre::{eyre, Context, ContextCompat, Report, Result};
|
||||
use sea_schema::sea_query::TableCreateStatement;
|
||||
use url::Url;
|
||||
|
||||
use crate::config::db::DbConfig;
|
||||
/// The flavour of database behind the connection URL; drives the
/// dialect-specific type names shown in generated schema comments.
///
/// `Copy`/`PartialEq`/`Eq` are derived so the value can be passed around
/// and compared cheaply instead of being `.clone()`d.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DbType {
    MySql,
    Postgres,
    Sqlite,
}
|
||||
|
||||
pub async fn get_tables(
|
||||
database_url: String,
|
||||
filter: Box<dyn Fn(&String) -> bool>,
|
||||
database_config: &DbConfig,
|
||||
) -> Result<(Vec<TableCreateStatement>, DbType)> {
|
||||
let url = Url::parse(&database_url)?;
|
||||
|
||||
tracing::trace!(?url);
|
||||
|
||||
let is_sqlite = url.scheme() == "sqlite";
|
||||
|
||||
let database_name: &str = (if !is_sqlite {
|
||||
let database_name = url
|
||||
.path_segments()
|
||||
.context("No database name as part of path")?
|
||||
.next()
|
||||
.context("No database name as part of path")?;
|
||||
|
||||
if database_name.is_empty() {
|
||||
return Err(eyre!("Database path name is empty"));
|
||||
}
|
||||
Ok::<&str, Report>(database_name)
|
||||
} else {
|
||||
Ok(Default::default())
|
||||
})?;
|
||||
|
||||
let (table_stmts, db_type) = match url.scheme() {
|
||||
"mysql" => {
|
||||
use sea_schema::mysql::discovery::SchemaDiscovery;
|
||||
use sqlx::MySql;
|
||||
|
||||
tracing::info!("Connecting to MySQL");
|
||||
let connection = sqlx_connect::<MySql>(
|
||||
database_config.max_connections,
|
||||
database_config.acquire_timeout,
|
||||
url.as_str(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Discovering schema");
|
||||
let schema_discovery = SchemaDiscovery::new(connection, database_name);
|
||||
let schema = schema_discovery.discover().await?;
|
||||
let table_stmts = schema
|
||||
.tables
|
||||
.into_iter()
|
||||
.filter(|schema| filter(&schema.info.name))
|
||||
// .filter(|schema| filter_hidden_tables(&schema.info.name))
|
||||
// .filter(|schema| filter_skip_tables(&schema.info.name))
|
||||
.map(|schema| schema.write())
|
||||
.collect();
|
||||
(table_stmts, DbType::MySql)
|
||||
}
|
||||
"sqlite" => {
|
||||
use sea_schema::sqlite::discovery::SchemaDiscovery;
|
||||
use sqlx::Sqlite;
|
||||
|
||||
tracing::info!("Connecting to SQLite");
|
||||
let connection = sqlx_connect::<Sqlite>(
|
||||
database_config.max_connections,
|
||||
database_config.acquire_timeout,
|
||||
url.as_str(),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Discovering schema");
|
||||
let schema_discovery = SchemaDiscovery::new(connection);
|
||||
let schema = schema_discovery
|
||||
.discover()
|
||||
.await?
|
||||
.merge_indexes_into_table();
|
||||
let table_stmts = schema
|
||||
.tables
|
||||
.into_iter()
|
||||
.filter(|schema| filter(&schema.name))
|
||||
// .filter(|schema| filter_hidden_tables(&schema.name))
|
||||
// .filter(|schema| filter_skip_tables(&schema.name))
|
||||
.map(|schema| schema.write())
|
||||
.collect();
|
||||
(table_stmts, DbType::Sqlite)
|
||||
}
|
||||
"postgres" | "potgresql" => {
|
||||
use sea_schema::postgres::discovery::SchemaDiscovery;
|
||||
use sqlx::Postgres;
|
||||
|
||||
tracing::info!("Connecting to Postgres");
|
||||
let schema = &database_config
|
||||
.database_schema
|
||||
.as_deref()
|
||||
.unwrap_or("public");
|
||||
let connection = sqlx_connect::<Postgres>(
|
||||
database_config.max_connections,
|
||||
database_config.acquire_timeout,
|
||||
url.as_str(),
|
||||
Some(schema),
|
||||
)
|
||||
.await?;
|
||||
tracing::info!("Discovering schema");
|
||||
let schema_discovery = SchemaDiscovery::new(connection, schema);
|
||||
let schema = schema_discovery.discover().await?;
|
||||
tracing::info!(?schema);
|
||||
let table_stmts = schema
|
||||
.tables
|
||||
.into_iter()
|
||||
.filter(|schema| filter(&schema.info.name))
|
||||
// .filter(|schema| filter_hidden_tables(&schema.info.name))
|
||||
// .filter(|schema| filter_skip_tables(&schema.info.name))
|
||||
.map(|schema| schema.write())
|
||||
.collect();
|
||||
(table_stmts, DbType::Postgres)
|
||||
}
|
||||
_ => unimplemented!("{} is not supported", url.scheme()),
|
||||
};
|
||||
tracing::info!("Schema discovered");
|
||||
|
||||
Ok((table_stmts, db_type))
|
||||
}
|
||||
async fn sqlx_connect<DB>(
|
||||
max_connections: u32,
|
||||
acquire_timeout: u64,
|
||||
url: &str,
|
||||
schema: Option<&str>,
|
||||
) -> Result<sqlx::Pool<DB>>
|
||||
where
|
||||
DB: sqlx::Database,
|
||||
for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
|
||||
{
|
||||
let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
|
||||
.max_connections(max_connections)
|
||||
.acquire_timeout(time::Duration::from_secs(acquire_timeout));
|
||||
// Set search_path for Postgres, E.g. Some("public") by default
|
||||
// MySQL & SQLite connection initialize with schema `None`
|
||||
if let Some(schema) = schema {
|
||||
let sql = format!("SET search_path = '{schema}'");
|
||||
pool_options = pool_options.after_connect(move |conn, _| {
|
||||
let sql = sql.clone();
|
||||
Box::pin(async move {
|
||||
sqlx::Executor::execute(conn, sql.as_str())
|
||||
.await
|
||||
.map(|_| ())
|
||||
})
|
||||
});
|
||||
}
|
||||
pool_options.connect(url).await.map_err(Into::into)
|
||||
}
|
||||
101
src/generator/models/file.rs
Normal file
101
src/generator/models/file.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::{config::Config, generator::file::GeneratedFileChunk};
|
||||
|
||||
use super::{
|
||||
comment::ModelCommentGenerator, discover::DbType, table::Table, CommentConfig, ModelConfig,
|
||||
};
|
||||
use color_eyre::Result;
|
||||
use comfy_table::{ContentArrangement, Table as CTable};
|
||||
use comment_parser::{CommentParser, Event};
|
||||
use handlebars::Handlebars;
|
||||
use sea_orm_codegen::OutputFile;
|
||||
use serde::Serialize;
|
||||
use tokio::fs;
|
||||
/// Generates the per-table model file chunks (module declaration plus the
/// commented model file itself).
#[derive(Debug, Clone)]
pub struct FileGenerator {
    // NOTE(review): neither field is read by the current associated
    // functions (both take their inputs as parameters) — confirm whether an
    // instance-based API is still planned, otherwise these can go.
    filename: String,
    table: Table,
}
|
||||
// #[derive(Debug, Clone, Serialize)]
|
||||
// pub struct FileContext {
|
||||
// entities_path: String,
|
||||
// model_path: Option<String>,
|
||||
// model_name: Option<String>,
|
||||
// active_model_name: Option<String>,
|
||||
// prelude_path: Option<String>,
|
||||
// }
|
||||
|
||||
impl FileGenerator {
|
||||
pub async fn generate_file<'a>(
|
||||
table: Table,
|
||||
config: &ModelConfig,
|
||||
handlebars: &'a Handlebars<'a>,
|
||||
) -> Result<Vec<GeneratedFileChunk>> {
|
||||
let mut file_chunks = Vec::new();
|
||||
file_chunks.push(GeneratedFileChunk {
|
||||
path: config.models_path.join("mod.rs"),
|
||||
content: format!("pub mod {};", table.name),
|
||||
priority: 0,
|
||||
});
|
||||
let filepath = config.models_path.join(format!("{}.rs", table.name));
|
||||
tracing::debug!(?filepath, "Generating file");
|
||||
if filepath.exists() {
|
||||
file_chunks
|
||||
.extend(Self::handle_existing_file(table, &filepath, config, handlebars).await?);
|
||||
} else {
|
||||
}
|
||||
|
||||
// let filepath = config.output.path.join(&self.filename);
|
||||
// let file_context = FileContext {
|
||||
// entities_path: self.entities_path.clone(),
|
||||
// model_name: self.table.name.clone(),
|
||||
// };
|
||||
// let generated_header = self.generate_header(config, db_type).await?;
|
||||
// if filepath.exists() {
|
||||
// let mut file_content = fs::read_to_string(filepath).await?;
|
||||
// if !config.output.models.comment.enable {
|
||||
// return Ok(file_content);
|
||||
// }
|
||||
// let rules = comment_parser::get_syntax("rust").unwrap();
|
||||
// let parser = CommentParser::new(&file_content, rules);
|
||||
// for comment in parser {
|
||||
// if let Event::BlockComment(body, _) = comment {
|
||||
// if body.contains(HEADER) {
|
||||
// tracing::debug!("Found header");
|
||||
// file_content = file_content.replace(body, &generated_header);
|
||||
// return Ok(file_content);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// Ok(format!("{}\n{}", generated_header, file_content))
|
||||
// } else {
|
||||
// let content = handlebars.render("model", &file_context)?;
|
||||
// Ok(format!("{}{}", generated_header, content))
|
||||
// }
|
||||
|
||||
// Ok(OutputFile {
|
||||
// name: self.filename.clone(),
|
||||
// content: self.generate_file(config, handlebars, db_type).await?,
|
||||
// })
|
||||
Ok(file_chunks)
|
||||
}
|
||||
async fn handle_existing_file<'a>(
|
||||
table: Table,
|
||||
filepath: &PathBuf,
|
||||
config: &ModelConfig,
|
||||
handlebars: &'a Handlebars<'a>,
|
||||
) -> Result<Vec<GeneratedFileChunk>> {
|
||||
let mut file_chunks = Vec::new();
|
||||
let mut file_content = fs::read_to_string(filepath).await?;
|
||||
if config.comment.enable {
|
||||
file_content = ModelCommentGenerator::generate_comment(table, &file_content, config)?;
|
||||
}
|
||||
file_chunks.push(GeneratedFileChunk {
|
||||
path: filepath.clone(),
|
||||
content: file_content,
|
||||
priority: 0,
|
||||
});
|
||||
Ok(file_chunks)
|
||||
}
|
||||
}
|
||||
153
src/generator/models/mod.rs
Normal file
153
src/generator/models/mod.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
use crate::config::{sea_orm_config::DateTimeCrate, Config};
|
||||
use color_eyre::Result;
|
||||
use discover::DbType;
|
||||
use file::FileGenerator;
|
||||
use handlebars::Handlebars;
|
||||
use sea_orm_codegen::{EntityTransformer, EntityWriterContext, OutputFile};
|
||||
use sea_schema::sea_query::TableCreateStatement;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use table::Table;
|
||||
|
||||
use super::file::{GeneratedFile, GeneratedFileChunk};
|
||||
pub mod column;
|
||||
pub mod comment;
|
||||
pub mod discover;
|
||||
pub mod file;
|
||||
pub mod table;
|
||||
/// Fully-resolved settings for model-file generation, flattened from the
/// global `Config` plus the database flavour detected during discovery.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    // Directory that receives generated model files.
    pub models_path: PathBuf,
    // NOTE(review): not read anywhere in this chunk — presumably toggles
    // prelude generation; confirm against the rest of the generator.
    pub prelude: bool,
    // Directory that receives the sea-orm entity files.
    pub entities_path: PathBuf,
    // Entities directory relative to `models_path`, used for `pub mod …`.
    pub relative_entities_path: String,
    // Master switch for model-file generation.
    pub enable: bool,
    // Settings for the schema-information comment.
    pub comment: CommentConfig,
    // Database flavour; drives dialect-specific type names.
    pub db_type: DbType,
}
|
||||
/// Resolved schema-comment settings: the global defaults, possibly overlaid
/// with per-file overrides via `CommentConfigSerde::merge`.
#[derive(Debug, Clone)]
pub struct CommentConfig {
    // Maximum rendered width of the column info table, when constrained.
    pub max_width: Option<u16>,
    // Whether to write the schema comment at all.
    pub enable: bool,
    // Include the "Table: <name>" line.
    pub table_name: bool,
    // Include the column info table.
    pub column_info: bool,
    // Include the column-name column.
    pub column_name: bool,
    // Include the Rust-type column.
    pub column_rust_type: bool,
    // Include the database-type column.
    pub column_db_type: bool,
    // Include the attributes column.
    pub column_attributes: bool,
    // Swallow per-file settings parse errors instead of failing generation.
    pub ignore_errors: bool,
    // Which date/time crate the generated entities use (chrono or time).
    pub date_time_crate: DateTimeCrate,
}
|
||||
/// Per-file override of `CommentConfig`, parsed from the YAML settings fence
/// inside an existing schema comment. Every field is optional; `None` means
/// "inherit the global setting" (applied in `merge`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommentConfigSerde {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_width: Option<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_name: Option<bool>,
    // The shorter YAML keys below map onto the `column_*` fields of
    // `CommentConfig` during `merge`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub info: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rust_type: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub db_type: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub attributes: Option<bool>,
}
|
||||
impl CommentConfigSerde {
|
||||
pub fn merge(&self, config: &CommentConfig) -> CommentConfig {
|
||||
CommentConfig {
|
||||
max_width: self.max_width.or(config.max_width),
|
||||
table_name: self.table_name.unwrap_or(config.table_name),
|
||||
column_name: self.name.unwrap_or(config.column_name),
|
||||
column_info: self.info.unwrap_or(config.column_info),
|
||||
column_rust_type: self.rust_type.unwrap_or(config.column_rust_type),
|
||||
column_db_type: self.db_type.unwrap_or(config.column_db_type),
|
||||
column_attributes: self.attributes.unwrap_or(config.column_attributes),
|
||||
ignore_errors: config.ignore_errors,
|
||||
enable: self.enable.unwrap_or(config.enable),
|
||||
date_time_crate: config.date_time_crate.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelConfig {
|
||||
pub fn new(config: Config, db_type: DbType) -> Self {
|
||||
let models_path = config.output.path.join(&config.output.models.path);
|
||||
let entities_path = models_path.join(&config.output.models.entities);
|
||||
ModelConfig {
|
||||
db_type,
|
||||
prelude: config.output.models.prelude,
|
||||
entities_path,
|
||||
models_path,
|
||||
relative_entities_path: config.output.models.entities.clone(),
|
||||
enable: config.output.models.enable,
|
||||
comment: CommentConfig {
|
||||
max_width: config.output.models.comment.max_width,
|
||||
enable: config.output.models.comment.enable,
|
||||
table_name: config.output.models.comment.table_name,
|
||||
column_name: config.output.models.comment.column_name,
|
||||
column_info: config.output.models.comment.column_info,
|
||||
column_rust_type: config.output.models.comment.column_rust_type,
|
||||
column_db_type: config.output.models.comment.column_db_type,
|
||||
column_attributes: config.output.models.comment.column_attributes,
|
||||
ignore_errors: config.output.models.comment.ignore_errors,
|
||||
date_time_crate: config.sea_orm.entity.date_time_crate,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// End-to-end model generation: discovers tables from the database, runs
/// sea-orm-codegen for the entities, and (when enabled) emits per-table
/// model files with schema comments.
pub async fn generate_models<'a>(
    database_url: &str,
    config: &Config,
    handlebars: &'a Handlebars<'a>,
) -> Result<Vec<GeneratedFileChunk>> {
    let mut files = Vec::new();
    let db_filter = config.sea_orm.entity.tables.get_filter();
    let (table_stmts, db_type) =
        discover::get_tables(database_url.to_owned(), db_filter, &config.db).await?;
    let model_config = ModelConfig::new(config.clone(), db_type);

    // Entities are always generated; model files only behind the flag below.
    let writer_context = config.clone().into();
    files.extend(
        generate_entities(table_stmts.clone(), model_config.clone(), writer_context).await?,
    );

    // Declare the entities module in the models mod.rs.
    files.push(GeneratedFileChunk {
        path: model_config.models_path.join("mod.rs"),
        content: format!("pub mod {};", model_config.relative_entities_path),
        priority: 0,
    });
    let tables = table_stmts
        .into_iter()
        .map(Table::new)
        .collect::<Result<Vec<Table>>>()?;

    if model_config.enable {
        for table in tables {
            files.extend(FileGenerator::generate_file(table, &model_config, handlebars).await?);
        }
    }
    Ok(files)
}
|
||||
|
||||
pub async fn generate_entities(
|
||||
table_statements: Vec<TableCreateStatement>,
|
||||
config: ModelConfig,
|
||||
writer_context: EntityWriterContext,
|
||||
) -> Result<Vec<GeneratedFileChunk>> {
|
||||
let output = EntityTransformer::transform(table_statements)?.generate(&writer_context);
|
||||
Ok(output
|
||||
.files
|
||||
.into_iter()
|
||||
.map(|OutputFile { name, content }| GeneratedFileChunk {
|
||||
path: config.entities_path.join(name),
|
||||
content,
|
||||
priority: 0,
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
52
src/generator/models/table.rs
Normal file
52
src/generator/models/table.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use super::column::Column;
|
||||
use crate::config::Config;
|
||||
use color_eyre::{eyre::eyre, Result};
|
||||
use sea_schema::sea_query::{self, ColumnDef, IndexCreateStatement, TableCreateStatement};
|
||||
use tracing::instrument;
|
||||
|
||||
/// A discovered table: its name plus the columns (with attributes) that
/// feed the schema-information comment.
#[derive(Debug, Clone)]
pub struct Table {
    // Table name as reported by the database.
    pub name: String,
    // Columns of the table, in the order returned by schema discovery.
    pub columns: Vec<Column>,
}
|
||||
|
||||
impl Table {
|
||||
pub fn new(statement: TableCreateStatement) -> Result<Table> {
|
||||
let table_name = match statement.get_table_name() {
|
||||
Some(table_ref) => match table_ref {
|
||||
sea_query::TableRef::Table(t)
|
||||
| sea_query::TableRef::SchemaTable(_, t)
|
||||
| sea_query::TableRef::DatabaseSchemaTable(_, _, t)
|
||||
| sea_query::TableRef::TableAlias(t, _)
|
||||
| sea_query::TableRef::SchemaTableAlias(_, t, _)
|
||||
| sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => t.to_string(),
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
None => return Err(eyre!("Table name not found")),
|
||||
};
|
||||
tracing::debug!(?table_name);
|
||||
let columns_raw = statement.get_columns();
|
||||
let indexes = statement.get_indexes();
|
||||
for column in columns_raw {
|
||||
tracing::debug!(?column);
|
||||
}
|
||||
for index in indexes {
|
||||
tracing::debug!(?index);
|
||||
}
|
||||
let columns = columns_raw
|
||||
.iter()
|
||||
.map(|column| {
|
||||
let name = column.get_column_name();
|
||||
let index = indexes
|
||||
.iter()
|
||||
.find(|index| index.get_index_spec().get_column_names().contains(&name));
|
||||
Column::new(column.clone(), index.cloned())
|
||||
})
|
||||
.collect::<Result<Vec<Column>>>()?;
|
||||
tracing::debug!(?columns);
|
||||
Ok(Table {
|
||||
columns,
|
||||
name: table_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user