From 17b1994b285e162f6809516bfb5449ecc3223b5b Mon Sep 17 00:00:00 2001 From: Nikkuss Date: Tue, 8 Apr 2025 08:13:41 +0400 Subject: [PATCH] fix discovery --- flake.nix | 30 +-- src/generator/modules/discovery/column.rs | 236 ++++++++++++++++++++++ src/generator/modules/discovery/mod.rs | 17 +- src/generator/modules/discovery/table.rs | 50 +++++ 4 files changed, 317 insertions(+), 16 deletions(-) create mode 100644 src/generator/modules/discovery/column.rs create mode 100644 src/generator/modules/discovery/table.rs diff --git a/flake.nix b/flake.nix index 0692ceb..a1ae41e 100644 --- a/flake.nix +++ b/flake.nix @@ -109,21 +109,21 @@ name = "process-compose"; config = (import ./process-compose.nix { inherit pkgs; }); # enableTui = true; - # modules = [ - # (process-compose.mkPostgres { - # name = "postgres"; - # initialDatabases = [ - # { - # name = "db"; - # user = "root"; - # password = "root"; - # } - # ]; - # }) - # (process-compose.mkRedis { - # name = "redis"; - # }) - # ]; + modules = [ + (process-compose.mkPostgres { + name = "postgres"; + initialDatabases = [ + { + name = "db"; + user = "root"; + password = "root"; + } + ]; + }) + (process-compose.mkRedis { + name = "redis"; + }) + ]; }; llvm-coverage = craneLib.cargoLlvmCov ( commonArgs diff --git a/src/generator/modules/discovery/column.rs b/src/generator/modules/discovery/column.rs new file mode 100644 index 0000000..97dfe82 --- /dev/null +++ b/src/generator/modules/discovery/column.rs @@ -0,0 +1,236 @@ +use color_eyre::{eyre::ContextCompat, Result}; +use heck::ToUpperCamelCase; +use sea_schema::sea_query::{ColumnDef, ColumnSpec, ColumnType, IndexCreateStatement}; + +use crate::config::sea_orm_config::DateTimeCrate; + +use super::db::DbType; +#[derive(Clone, Debug)] +pub struct Column { + pub name: String, + pub col_type: ColumnType, + pub attrs: Vec, +} + +impl Column { + pub fn new(column: ColumnDef, index: Option) -> Result { + let name = column.get_column_name(); + let col_type = column + .get_column_type() + .context("Unable to get column type")? + .clone(); + let mut attrs = column.get_column_spec().clone(); + if let Some(index) = index { + if index.is_unique_key() { + attrs.push(ColumnSpec::UniqueKey) + } + if index.is_primary_key() { + attrs.push(ColumnSpec::PrimaryKey); + } + } + Ok(Column { + name: name.to_string(), + col_type, + attrs: attrs.to_vec(), + }) + } + // pub fn get_info_row(&self, config: &ModelConfig) -> Result> { + // let column_type_rust = self.get_rs_type(&config.comment.date_time_crate); + // let column_type = self.get_db_type(&config.db_type); + // let attrs = self.attrs_to_string(); + // let mut cols = Vec::new(); + // if config.comment.column_name { + // cols.push(Cell::new(self.name.clone())) + // } + // if config.comment.column_name { + // cols.push(Cell::new(column_type.clone())) + // } + // if config.comment.column_rust_type { + // cols.push(Cell::new(column_type_rust.clone())) + // } + // if config.comment.column_attributes { + // cols.push(Cell::new(attrs.clone())); + // } + // Ok(cols) + // } + pub fn attrs_to_string(&self) -> String { + self.attrs + .iter() + .filter_map(Self::get_addr_type) + .map(|s| s.to_string()) + .collect::>() + .join(", ") + } + pub fn get_addr_type(attr: &ColumnSpec) -> Option { + match attr { + ColumnSpec::PrimaryKey => Some("primary key".to_owned()), + ColumnSpec::Null => unimplemented!(), + ColumnSpec::NotNull => Some("not null".to_owned()), + ColumnSpec::Default(_) => unimplemented!(), + ColumnSpec::AutoIncrement => Some("autoincrement".to_owned()), + ColumnSpec::UniqueKey => Some("unique key".to_owned()), + ColumnSpec::Check(_) => unimplemented!(), + ColumnSpec::Generated { .. } => unimplemented!(), + ColumnSpec::Extra(_) => unimplemented!(), + ColumnSpec::Comment(_) => unimplemented!(), + ColumnSpec::Using(_) => unimplemented!(), + } + } + pub fn get_db_type(&self, db_type: &DbType) -> String { + fn write_db_type(col_type: &ColumnType, db_type: &DbType) -> String { + #[allow(unreachable_patterns)] + match (col_type, db_type) { + (ColumnType::Char(_), _) => "char".to_owned(), + (ColumnType::String(_), _) => "varchar".to_owned(), + (ColumnType::Text, _) => "text".to_owned(), + (ColumnType::TinyInteger, DbType::MySql | DbType::Sqlite) => "tinyint".to_owned(), + (ColumnType::TinyInteger, DbType::Postgres) => "smallint".to_owned(), + (ColumnType::SmallInteger, _) => "smallint".to_owned(), + (ColumnType::Integer, DbType::MySql) => "int".to_owned(), + (ColumnType::Integer, _) => "integer".to_owned(), + (ColumnType::BigInteger, DbType::MySql | DbType::Postgres) => "bigint".to_owned(), + (ColumnType::BigInteger, DbType::Sqlite) => "integer".to_owned(), + (ColumnType::TinyUnsigned, DbType::MySql) => "tinyint unsigned".to_owned(), + (ColumnType::TinyUnsigned, DbType::Postgres) => "smallint".to_owned(), + (ColumnType::TinyUnsigned, DbType::Sqlite) => "tinyint".to_owned(), + (ColumnType::SmallUnsigned, DbType::MySql) => "smallint unsigned".to_owned(), + (ColumnType::SmallUnsigned, DbType::Postgres | DbType::Sqlite) => { + "smallint".to_owned() + } + (ColumnType::Unsigned, DbType::MySql) => "int unsigned".to_owned(), + (ColumnType::Unsigned, DbType::Postgres | DbType::Sqlite) => "integer".to_owned(), + (ColumnType::BigUnsigned, DbType::MySql) => "bigint unsigned".to_owned(), + (ColumnType::BigUnsigned, DbType::Postgres) => "bigint".to_owned(), + (ColumnType::BigUnsigned, DbType::Sqlite) => "integer".to_owned(), + (ColumnType::Float, DbType::MySql | DbType::Sqlite) => "float".to_owned(), + (ColumnType::Float, DbType::Postgres) => "real".to_owned(), + (ColumnType::Double, DbType::MySql | DbType::Sqlite) => "double".to_owned(), + (ColumnType::Double, DbType::Postgres) => "double precision".to_owned(), + (ColumnType::Decimal(_), DbType::MySql | DbType::Postgres) => "decimal".to_owned(), + (ColumnType::Decimal(_), DbType::Sqlite) => "real".to_owned(), + (ColumnType::DateTime, DbType::MySql) => "datetime".to_owned(), + (ColumnType::DateTime, DbType::Postgres) => "timestamp w/o tz".to_owned(), + (ColumnType::DateTime, DbType::Sqlite) => "datetime_text".to_owned(), + (ColumnType::Timestamp, DbType::MySql | DbType::Postgres) => "timestamp".to_owned(), + (ColumnType::Timestamp, DbType::Sqlite) => "timestamp_text".to_owned(), + (ColumnType::TimestampWithTimeZone, DbType::MySql) => "timestamp".to_owned(), + (ColumnType::TimestampWithTimeZone, DbType::Postgres) => { + "timestamp w tz".to_owned() + } + (ColumnType::TimestampWithTimeZone, DbType::Sqlite) => { + "timestamp_with_timezone_text".to_owned() + } + (ColumnType::Time, DbType::MySql | DbType::Postgres) => "time".to_owned(), + (ColumnType::Time, DbType::Sqlite) => "time_text".to_owned(), + (ColumnType::Date, DbType::MySql | DbType::Postgres) => "date".to_owned(), + (ColumnType::Date, DbType::Sqlite) => "date_text".to_owned(), + (ColumnType::Year, DbType::MySql) => "year".to_owned(), + (ColumnType::Interval(_, _), DbType::Postgres) => "interval".to_owned(), + (ColumnType::Blob, DbType::MySql | DbType::Sqlite) => "blob".to_owned(), + (ColumnType::Blob, DbType::Postgres) => "bytea".to_owned(), + (ColumnType::Binary(_), DbType::MySql) => "binary".to_owned(), + (ColumnType::Binary(_), DbType::Postgres) => "bytea".to_owned(), + (ColumnType::Binary(_), DbType::Sqlite) => "blob".to_owned(), + (ColumnType::VarBinary(_), DbType::MySql) => "varbinary".to_owned(), + (ColumnType::VarBinary(_), DbType::Postgres) => "bytea".to_owned(), + (ColumnType::VarBinary(_), DbType::Sqlite) => "varbinary_blob".to_owned(), + (ColumnType::Bit(_), DbType::MySql | DbType::Postgres) => "bit".to_owned(), + (ColumnType::VarBit(_), DbType::MySql) => "bit".to_owned(), + (ColumnType::VarBit(_), DbType::Postgres) => "varbit".to_owned(), + (ColumnType::Boolean, DbType::MySql | DbType::Postgres) => "bool".to_owned(), + (ColumnType::Boolean, DbType::Sqlite) => "boolean".to_owned(), + (ColumnType::Money(_), DbType::MySql) => "decimal".to_owned(), + (ColumnType::Money(_), DbType::Postgres) => "money".to_owned(), + (ColumnType::Money(_), DbType::Sqlite) => "real_money".to_owned(), + (ColumnType::Json, DbType::MySql | DbType::Postgres) => "json".to_owned(), + (ColumnType::Json, DbType::Sqlite) => "json_text".to_owned(), + (ColumnType::JsonBinary, DbType::MySql) => "json".to_owned(), + (ColumnType::JsonBinary, DbType::Postgres) => "jsonb".to_owned(), + (ColumnType::JsonBinary, DbType::Sqlite) => "jsonb_text".to_owned(), + (ColumnType::Uuid, DbType::MySql) => "binary(16)".to_owned(), + (ColumnType::Uuid, DbType::Postgres) => "uuid".to_owned(), + (ColumnType::Uuid, DbType::Sqlite) => "uuid_text".to_owned(), + (ColumnType::Enum { name, .. }, DbType::MySql) => { + format!("ENUM({})", name.to_string().to_upper_camel_case()) + } + (ColumnType::Enum { name, .. }, DbType::Postgres) => { + name.to_string().to_uppercase() + } + (ColumnType::Enum { .. }, DbType::Sqlite) => "enum_text".to_owned(), + (ColumnType::Array(column_type), DbType::Postgres) => { + format!("{}[]", write_db_type(column_type, db_type)).to_uppercase() + } + (ColumnType::Vector(_), DbType::Postgres) => "vector".to_owned(), + (ColumnType::Cidr, DbType::Postgres) => "cidr".to_owned(), + (ColumnType::Inet, DbType::Postgres) => "inet".to_owned(), + (ColumnType::MacAddr, DbType::Postgres) => "macaddr".to_owned(), + (ColumnType::LTree, DbType::Postgres) => "ltree".to_owned(), + + _ => unimplemented!(), + } + } + write_db_type(&self.col_type, db_type) + } + pub fn get_rs_type(&self, date_time_crate: &DateTimeCrate) -> String { + fn write_rs_type(col_type: &ColumnType, date_time_crate: &DateTimeCrate) -> String { + #[allow(unreachable_patterns)] + match col_type { + ColumnType::Char(_) + | ColumnType::String(_) + | ColumnType::Text + | ColumnType::Custom(_) => "String".to_owned(), + ColumnType::TinyInteger => "i8".to_owned(), + ColumnType::SmallInteger => "i16".to_owned(), + ColumnType::Integer => "i32".to_owned(), + ColumnType::BigInteger => "i64".to_owned(), + ColumnType::TinyUnsigned => "u8".to_owned(), + ColumnType::SmallUnsigned => "u16".to_owned(), + ColumnType::Unsigned => "u32".to_owned(), + ColumnType::BigUnsigned => "u64".to_owned(), + ColumnType::Float => "f32".to_owned(), + ColumnType::Double => "f64".to_owned(), + ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(), + ColumnType::Date => match date_time_crate { + DateTimeCrate::Chrono => "Date".to_owned(), + DateTimeCrate::Time => "TimeDate".to_owned(), + }, + ColumnType::Time => match date_time_crate { + DateTimeCrate::Chrono => "Time".to_owned(), + DateTimeCrate::Time => "TimeTime".to_owned(), + }, + ColumnType::DateTime => match date_time_crate { + DateTimeCrate::Chrono => "DateTime".to_owned(), + DateTimeCrate::Time => "TimeDateTime".to_owned(), + }, + ColumnType::Timestamp => match date_time_crate { + DateTimeCrate::Chrono => "DateTimeUtc".to_owned(), + DateTimeCrate::Time => "TimeDateTime".to_owned(), + }, + ColumnType::TimestampWithTimeZone => match date_time_crate { + DateTimeCrate::Chrono => "DateTimeWithTimeZone".to_owned(), + DateTimeCrate::Time => "TimeDateTimeWithTimeZone".to_owned(), + }, + ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(), + ColumnType::Uuid => "Uuid".to_owned(), + ColumnType::Binary(_) | ColumnType::VarBinary(_) | ColumnType::Blob => { + "Vec".to_owned() + } + ColumnType::Boolean => "bool".to_owned(), + ColumnType::Enum { name, .. } => name.to_string().to_upper_camel_case(), + ColumnType::Array(column_type) => { + format!("Vec<{}>", write_rs_type(column_type, date_time_crate)) + } + ColumnType::Vector(_) => "::pgvector::Vector".to_owned(), + ColumnType::Bit(None | Some(1)) => "bool".to_owned(), + ColumnType::Bit(_) | ColumnType::VarBit(_) => "Vec".to_owned(), + ColumnType::Year => "i32".to_owned(), + ColumnType::Cidr | ColumnType::Inet => "IpNetwork".to_owned(), + ColumnType::Interval(_, _) | ColumnType::MacAddr | ColumnType::LTree => { + "String".to_owned() + } + _ => unimplemented!(), + } + } + write_rs_type(&self.col_type, date_time_crate) + } +} diff --git a/src/generator/modules/discovery/mod.rs b/src/generator/modules/discovery/mod.rs index 08ab402..5d96f37 100644 --- a/src/generator/modules/discovery/mod.rs +++ b/src/generator/modules/discovery/mod.rs @@ -1,10 +1,14 @@ +pub mod column; pub mod db; +pub mod table; use crate::generator::DatabaseUrl; use super::{Module, ModulesContext}; use color_eyre::Result; +use db::DbType; use serde::Deserialize; use serde_inline_default::serde_inline_default; +use table::Table; #[derive(Debug, Clone, Deserialize)] #[serde(default)] @@ -83,8 +87,10 @@ impl DiscoveryFilterConfig { } } +#[derive(Debug, Clone)] pub struct DiscoveredSchema { pub tables: Vec, + pub database_type: DbType, } #[derive(Debug)] @@ -114,7 +120,16 @@ impl Module for DiscoveryModule { let url = url.clone(); let (stmts, db_type) = db::get_tables(url.0, config).await?; - tracing::info!(?stmts, ?db_type); + let tables = stmts + .into_iter() + .map(Table::new) + .collect::>>()?; + tracing::info!(?tables, ?db_type); + let discovered = DiscoveredSchema { + tables, + database_type: db_type, + }; + ctx.get_anymap_mut().insert(discovered); // db::generate(ctx).await?; } Ok(()) diff --git a/src/generator/modules/discovery/table.rs b/src/generator/modules/discovery/table.rs new file mode 100644 index 0000000..ed8e381 --- /dev/null +++ b/src/generator/modules/discovery/table.rs @@ -0,0 +1,50 @@ +use super::column::Column; +use color_eyre::{eyre::eyre, Result}; +use sea_schema::sea_query::{self, TableCreateStatement}; + +#[derive(Debug, Clone)] +pub struct Table { + pub name: String, + pub columns: Vec, +} + +impl Table { + pub fn new(statement: TableCreateStatement) -> Result
{ + let table_name = match statement.get_table_name() { + Some(table_ref) => match table_ref { + sea_query::TableRef::Table(t) + | sea_query::TableRef::SchemaTable(_, t) + | sea_query::TableRef::DatabaseSchemaTable(_, _, t) + | sea_query::TableRef::TableAlias(t, _) + | sea_query::TableRef::SchemaTableAlias(_, t, _) + | sea_query::TableRef::DatabaseSchemaTableAlias(_, _, t, _) => t.to_string(), + _ => unimplemented!(), + }, + None => return Err(eyre!("Table name not found")), + }; + tracing::debug!(?table_name); + let columns_raw = statement.get_columns(); + let indexes = statement.get_indexes(); + for column in columns_raw { + tracing::debug!(?column); + } + for index in indexes { + tracing::debug!(?index); + } + let columns = columns_raw + .iter() + .map(|column| { + let name = column.get_column_name(); + let index = indexes + .iter() + .find(|index| index.get_index_spec().get_column_names().contains(&name)); + Column::new(column.clone(), index.cloned()) + }) + .collect::>>()?; + tracing::debug!(?columns); + Ok(Table { + columns, + name: table_name, + }) + } +}