meerschaum.utils.sql
Flavor-specific SQL tools.
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

"""
Flavor-specific SQL tools.
"""

from __future__ import annotations

from datetime import datetime, timezone, timedelta
import meerschaum as mrsm
from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple
### Preserve legacy imports.
from meerschaum.utils.dtypes.sql import (
    DB_TO_PD_DTYPES,
    PD_TO_DB_DTYPES_FLAVORS,
    get_pd_type_from_db_type as get_pd_type,
    get_db_type_from_pd_type as get_db_type,
    TIMEZONE_NAIVE_FLAVORS,
)
from meerschaum.utils.warnings import warn
from meerschaum.utils.debug import dprint

test_queries = {
    'default'  : 'SELECT 1',
    'oracle'   : 'SELECT 1 FROM DUAL',
    'informix' : 'SELECT COUNT(*) FROM systables',
    'hsqldb'   : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS',
}
### `table_name` is the escaped name of the table.
### `table` is the unescaped name of the table.
exists_queries = {
    'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0",
}
version_queries = {
    'default': "SELECT VERSION() AS {version_name}",
    'sqlite': "SELECT SQLITE_VERSION() AS {version_name}",
    'mssql': "SELECT @@version",
    'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1",
}
SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'}
DROP_IF_EXISTS_FLAVORS = {
    'timescaledb', 'postgresql', 'postgis', 'citus', 'mssql', 'mysql', 'mariadb', 'sqlite',
}
DROP_INDEX_IF_EXISTS_FLAVORS = {
    'mssql', 'timescaledb', 'postgresql', 'postgis', 'sqlite', 'citus',
}
SKIP_AUTO_INCREMENT_FLAVORS = {'citus', 'duckdb'}
COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'postgresql', 'postgis', 'citus'}
UPDATE_QUERIES = {
    'default': """
        UPDATE {target_table_name} AS f
        {sets_subquery_none}
        FROM {target_table_name} AS t
        INNER JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
            ON
                {and_subquery_t}
        WHERE
            {and_subquery_f}
            AND
            {date_bounds_subquery}
    """,
    'timescaledb-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'postgresql-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'postgis-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'citus-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'cockroachdb-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'mysql': """
        UPDATE {target_table_name} AS f
        JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
            ON
                {and_subquery_f}
        {sets_subquery_f}
        WHERE
            {date_bounds_subquery}
    """,
    'mysql-upsert': """
        INSERT {ignore}INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        {on_duplicate_key_update}
            {cols_equal_values}
    """,
    'mariadb': """
        UPDATE {target_table_name} AS f
        JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
            ON
                {and_subquery_f}
        {sets_subquery_f}
        WHERE
            {date_bounds_subquery}
    """,
    'mariadb-upsert': """
        INSERT {ignore}INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        {on_duplicate_key_update}
            {cols_equal_values}
    """,
    'mssql': """
        {with_temp_date_bounds}
        MERGE {target_table_name} f
        USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON
                {and_subquery_f}
                AND
                {date_bounds_subquery}
        WHEN MATCHED THEN
            UPDATE
            {sets_subquery_none};
    """,
    'mssql-upsert': [
        "{identity_insert_on}",
        """
        {with_temp_date_bounds}
        MERGE {target_table_name} f
        USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON
                {and_subquery_f}
                AND
                {date_bounds_subquery}{when_matched_update_sets_subquery_none}
        WHEN NOT MATCHED THEN
            INSERT ({patch_cols_str})
            VALUES ({patch_cols_prefixed_str});
        """,
        "{identity_insert_off}",
    ],
    'oracle': """
        MERGE INTO {target_table_name} f
        USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON (
                {and_subquery_f}
                AND
                {date_bounds_subquery}
            )
        WHEN MATCHED THEN
            UPDATE
            {sets_subquery_none}
    """,
    'oracle-upsert': """
        MERGE INTO {target_table_name} f
        USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON (
                {and_subquery_f}
                AND
                {date_bounds_subquery}
            ){when_matched_update_sets_subquery_none}
        WHEN NOT MATCHED THEN
            INSERT ({patch_cols_str})
            VALUES ({patch_cols_prefixed_str})
    """,
    'sqlite-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        WHERE true
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'sqlite_delete_insert': [
        """
        DELETE FROM {target_table_name} AS f
        WHERE ROWID IN (
            SELECT t.ROWID
            FROM {target_table_name} AS t
            INNER JOIN (SELECT * FROM {patch_table_name}) AS p
                ON {and_subquery_t}
        );
        """,
        """
        INSERT INTO {target_table_name} AS f
        SELECT {patch_cols_str} FROM {patch_table_name} AS p
        """,
    ],
}
columns_types_queries = {
    'default': """
        SELECT
            table_catalog AS database,
            table_schema AS schema,
            table_name AS table,
            column_name AS column,
            data_type AS type,
            numeric_precision,
            numeric_scale
        FROM information_schema.columns
        WHERE table_name IN ('{table}', '{table_trunc}')
    """,
    'sqlite': """
        SELECT
            '' "database",
            '' "schema",
            m.name "table",
            p.name "column",
            p.type "type"
        FROM sqlite_master m
        LEFT OUTER JOIN pragma_table_info(m.name) p
            ON m.name <> p.name
        WHERE m.type = 'table'
            AND m.name IN ('{table}', '{table_trunc}')
    """,
    'mssql': """
        SELECT
            TABLE_CATALOG AS [database],
            TABLE_SCHEMA AS [schema],
            TABLE_NAME AS [table],
            COLUMN_NAME AS [column],
            DATA_TYPE AS [type],
            NUMERIC_PRECISION AS [numeric_precision],
            NUMERIC_SCALE AS [numeric_scale]
        FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN (
            '{table}',
            '{table_trunc}'
        )
    """,
    'mysql': """
        SELECT
            TABLE_SCHEMA `database`,
            TABLE_SCHEMA `schema`,
            TABLE_NAME `table`,
            COLUMN_NAME `column`,
            DATA_TYPE `type`,
            NUMERIC_PRECISION `numeric_precision`,
            NUMERIC_SCALE `numeric_scale`
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'mariadb': """
        SELECT
            TABLE_SCHEMA `database`,
            TABLE_SCHEMA `schema`,
            TABLE_NAME `table`,
            COLUMN_NAME `column`,
            DATA_TYPE `type`,
            NUMERIC_PRECISION `numeric_precision`,
            NUMERIC_SCALE `numeric_scale`
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'oracle': """
        SELECT
            NULL AS "database",
            NULL AS "schema",
            TABLE_NAME AS "table",
            COLUMN_NAME AS "column",
            DATA_TYPE AS "type",
            DATA_PRECISION AS "numeric_precision",
            DATA_SCALE AS "numeric_scale"
        FROM all_tab_columns
        WHERE TABLE_NAME IN (
            '{table}',
            '{table_trunc}',
            '{table_lower}',
            '{table_lower_trunc}',
            '{table_upper}',
            '{table_upper_trunc}'
        )
    """,
}
hypertable_queries = {
    'timescaledb': 'SELECT hypertable_size(\'{table_name}\')',
    'citus': 'SELECT citus_table_size(\'{table_name}\')',
}
columns_indices_queries = {
    'default': """
        SELECT
            current_database() AS "database",
            n.nspname AS "schema",
            t.relname AS "table",
            c.column_name AS "column",
            i.relname AS "index",
            CASE WHEN con.contype = 'p' THEN 'PRIMARY KEY' ELSE 'INDEX' END AS "index_type"
        FROM pg_class t
        INNER JOIN pg_index AS ix
            ON t.oid = ix.indrelid
        INNER JOIN pg_class AS i
            ON i.oid = ix.indexrelid
        INNER JOIN pg_namespace AS n
            ON n.oid = t.relnamespace
        INNER JOIN pg_attribute AS a
            ON a.attnum = ANY(ix.indkey)
            AND a.attrelid = t.oid
        INNER JOIN information_schema.columns AS c
            ON c.column_name = a.attname
            AND c.table_name = t.relname
            AND c.table_schema = n.nspname
        LEFT JOIN pg_constraint AS con
            ON con.conindid = i.oid
            AND con.contype = 'p'
        WHERE
            t.relname IN ('{table}', '{table_trunc}')
            AND n.nspname = '{schema}'
    """,
    'sqlite': """
        WITH indexed_columns AS (
            SELECT
                '{table}' AS table_name,
                pi.name AS column_name,
                i.name AS index_name,
                'INDEX' AS index_type
            FROM
                sqlite_master AS i,
                pragma_index_info(i.name) AS pi
            WHERE
                i.type = 'index'
                AND i.tbl_name = '{table}'
        ),
        primary_key_columns AS (
            SELECT
                '{table}' AS table_name,
                ti.name AS column_name,
                'PRIMARY_KEY' AS index_name,
                'PRIMARY KEY' AS index_type
            FROM
                pragma_table_info('{table}') AS ti
            WHERE
                ti.pk > 0
        )
        SELECT
            NULL AS "database",
            NULL AS "schema",
            "table_name" AS "table",
            "column_name" AS "column",
            "index_name" AS "index",
            "index_type"
        FROM indexed_columns
        UNION ALL
        SELECT
            NULL AS "database",
            NULL AS "schema",
            table_name AS "table",
            column_name AS "column",
            index_name AS "index",
            index_type
        FROM primary_key_columns
    """,
    'mssql': """
        SELECT
            NULL AS [database],
            s.name AS [schema],
            t.name AS [table],
            c.name AS [column],
            i.name AS [index],
            CASE
                WHEN kc.type = 'PK' THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS [index_type],
            CASE
                WHEN i.type = 1 THEN CAST(1 AS BIT)
                ELSE CAST(0 AS BIT)
            END AS [clustered]
        FROM
            sys.schemas s
        INNER JOIN sys.tables t
            ON s.schema_id = t.schema_id
        INNER JOIN sys.indexes i
            ON t.object_id = i.object_id
        INNER JOIN sys.index_columns ic
            ON i.object_id = ic.object_id
            AND i.index_id = ic.index_id
        INNER JOIN sys.columns c
            ON ic.object_id = c.object_id
            AND ic.column_id = c.column_id
        LEFT JOIN sys.key_constraints kc
            ON kc.parent_object_id = i.object_id
            AND kc.type = 'PK'
            AND kc.name = i.name
        WHERE
            t.name IN ('{table}', '{table_trunc}')
            AND s.name = '{schema}'
            AND i.type IN (1, 2)
    """,
    'oracle': """
        SELECT
            NULL AS "database",
            ic.table_owner AS "schema",
            ic.table_name AS "table",
            ic.column_name AS "column",
            i.index_name AS "index",
            CASE
                WHEN c.constraint_type = 'P' THEN 'PRIMARY KEY'
                WHEN i.uniqueness = 'UNIQUE' THEN 'UNIQUE INDEX'
                ELSE 'INDEX'
            END AS index_type
        FROM
            all_ind_columns ic
        INNER JOIN all_indexes i
            ON ic.index_name = i.index_name
            AND ic.table_owner = i.owner
        LEFT JOIN all_constraints c
            ON i.index_name = c.constraint_name
            AND i.table_owner = c.owner
            AND c.constraint_type = 'P'
        WHERE ic.table_name IN (
            '{table}',
            '{table_trunc}',
            '{table_upper}',
            '{table_upper_trunc}'
        )
    """,
    'mysql': """
        SELECT
            TABLE_SCHEMA AS `database`,
            TABLE_SCHEMA AS `schema`,
            TABLE_NAME AS `table`,
            COLUMN_NAME AS `column`,
            INDEX_NAME AS `index`,
            CASE
                WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS `index_type`
        FROM
            information_schema.STATISTICS
        WHERE
            TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'mariadb': """
        SELECT
            TABLE_SCHEMA AS `database`,
            TABLE_SCHEMA AS `schema`,
            TABLE_NAME AS `table`,
            COLUMN_NAME AS `column`,
            INDEX_NAME AS `index`,
            CASE
                WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS `index_type`
        FROM
            information_schema.STATISTICS
        WHERE
            TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
}
reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = {
    'default': """
        SELECT SETVAL(pg_get_serial_sequence('{table_name}', '{column}'), {val})
        FROM {table_name}
    """,
    'mssql': """
        DBCC CHECKIDENT ('{table_name}', RESEED, {val})
    """,
    'mysql': """
        ALTER TABLE {table_name} AUTO_INCREMENT = {val}
    """,
    'mariadb': """
        ALTER TABLE {table_name} AUTO_INCREMENT = {val}
    """,
    'sqlite': """
        UPDATE sqlite_sequence
        SET seq = {val}
        WHERE name = '{table}'
    """,
    'oracle': (
        "ALTER TABLE {table_name} MODIFY {column_name} "
        "GENERATED BY DEFAULT ON NULL AS IDENTITY (START WITH {val_plus_1})"
    ),
}
table_wrappers = {
    'default'    : ('"', '"'),
    'timescaledb': ('"', '"'),
    'citus'      : ('"', '"'),
    'duckdb'     : ('"', '"'),
    'postgresql' : ('"', '"'),
    'postgis'    : ('"', '"'),
    'sqlite'     : ('"', '"'),
    'mysql'      : ('`', '`'),
    'mariadb'    : ('`', '`'),
    'mssql'      : ('[', ']'),
    'cockroachdb': ('"', '"'),
    'oracle'     : ('"', '"'),
}
max_name_lens = {
    'default'    : 64,
    'mssql'      : 128,
    'oracle'     : 30,
    'postgresql' : 64,
    'postgis'    : 64,
    'timescaledb': 64,
    'citus'      : 64,
    'cockroachdb': 64,
    'sqlite'     : 1024,  ### Probably more, but 1024 seems more than reasonable.
    'mysql'      : 64,
    'mariadb'    : 64,
}
json_flavors = {'postgresql', 'postgis', 'timescaledb', 'citus', 'cockroachdb'}
NO_SCHEMA_FLAVORS = {'oracle', 'sqlite', 'mysql', 'mariadb', 'duckdb'}
DEFAULT_SCHEMA_FLAVORS = {
    'postgresql': 'public',
    'postgis': 'public',
    'timescaledb': 'public',
    'citus': 'public',
    'cockroachdb': 'public',
    'mysql': 'mysql',
    'mariadb': 'mysql',
    'mssql': 'dbo',
}
OMIT_NULLSFIRST_FLAVORS = {'mariadb', 'mysql', 'mssql'}

SINGLE_ALTER_TABLE_FLAVORS = {'duckdb', 'sqlite', 'mssql', 'oracle'}
NO_CTE_FLAVORS = {'mysql', 'mariadb'}
NO_SELECT_INTO_FLAVORS = {'sqlite', 'oracle', 'mysql', 'mariadb', 'duckdb'}


def clean(substring: str) -> None:
    """
    Ensure a substring is clean enough to be inserted into a SQL query.
    Raises an exception when banned words are used.
    """
    from meerschaum.utils.warnings import error
    banned_symbols = [';', '--', 'drop ']
    for symbol in banned_symbols:
        if symbol in str(substring).lower():
            error(f"Invalid string: '{substring}'")


def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now',
    db_type: Optional[str] = None,
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'postgis'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime, int], default `'now'`
        Base datetime to which to add dateparts.

    db_type: Optional[str], default None
        If provided, cast the datetime string as the type.
        Otherwise, infer this from the input datetime value.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor='mssql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
    >>> dateadd_str(
    ...     flavor='postgresql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type
    dateutil_parser = attempt_import('dateutil.parser')
    if 'int' in str(type(begin)).lower():
        num_str = str(begin)
        if number is not None and number != 0:
            num_str += (
                f' + {number}'
                if number > 0
                else f" - {number * -1}"
            )
        return num_str
    if not begin:
        return ''

    _original_begin = begin
    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    dt_is_utc = (
        begin_time.tzinfo is not None
        if begin_time is not None
        else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1])
    )
    if db_type:
        db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower()
        dt_is_utc = dt_is_utc or db_type_is_utc
    db_type = db_type or get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    if flavor in ('postgresql', 'postgis', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = (
            f"CAST({begin} AS DATETIME(6))"
            if begin != 'now'
            else 'UTC_TIMESTAMP(6)'
        )
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        if begin == 'now':
            ### NOTE: This format string must line up with `dt_format` below.
            begin = str(
                datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y-%m-%d %H:%M:%S.%f')
            )
        elif begin_time:
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
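### A hedged sketch for reference (traced from the 'sqlite' branch above;
### the flavor and values here are illustrative, not from the source):
###
### >>> dateadd_str(flavor='sqlite', begin=datetime(2022, 1, 1), number=1)
### "datetime('2022-01-01 00:00:00', '1 day')"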
731 732 """ 733 import warnings 734 from meerschaum.connectors.poll import retry_connect 735 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 736 _default_kw.update(kw) 737 with warnings.catch_warnings(): 738 warnings.filterwarnings('ignore', 'Could not') 739 try: 740 return retry_connect(**_default_kw) 741 except Exception: 742 return False 743 744 745def get_distinct_col_count( 746 col: str, 747 query: str, 748 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 749 debug: bool = False 750) -> Optional[int]: 751 """ 752 Returns the number of distinct items in a column of a SQL query. 753 754 Parameters 755 ---------- 756 col: str: 757 The column in the query to count. 758 759 query: str: 760 The SQL query to count from. 761 762 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 763 The SQLConnector to execute the query. 764 765 debug: bool, default False: 766 Verbosity toggle. 767 768 Returns 769 ------- 770 An `int` of the number of columns in the query or `None` if the query fails. 771 772 """ 773 if connector is None: 774 connector = mrsm.get_connector('sql') 775 776 _col_name = sql_item_name(col, connector.flavor, None) 777 778 _meta_query = ( 779 f""" 780 WITH src AS ( {query} ), 781 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 782 SELECT COUNT(*) FROM dist""" 783 ) if connector.flavor not in ('mysql', 'mariadb') else ( 784 f""" 785 SELECT COUNT(*) 786 FROM ( 787 SELECT DISTINCT {_col_name} 788 FROM ({query}) AS src 789 ) AS dist""" 790 ) 791 792 result = connector.value(_meta_query, debug=debug) 793 try: 794 return int(result) 795 except Exception: 796 return None 797 798 799def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 800 """ 801 Parse SQL items depending on the flavor. 802 803 Parameters 804 ---------- 805 item: str 806 The database item (table, view, etc.) in need of quotes. 807 808 flavor: str 809 The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.). 810 811 schema: Optional[str], default None 812 If provided, prefix the table name with the schema. 813 814 Returns 815 ------- 816 A `str` which contains the input `item` wrapped in the corresponding escape characters. 817 818 Examples 819 -------- 820 >>> sql_item_name('table', 'sqlite') 821 '"table"' 822 >>> sql_item_name('table', 'mssql') 823 "[table]" 824 >>> sql_item_name('table', 'postgresql', schema='abc') 825 '"abc"."table"' 826 827 """ 828 truncated_item = truncate_item_name(str(item), flavor) 829 if flavor == 'oracle': 830 truncated_item = pg_capital(truncated_item, quote_capitals=True) 831 ### NOTE: System-reserved words must be quoted. 832 if truncated_item.lower() in ( 833 'float', 'varchar', 'nvarchar', 'clob', 834 'boolean', 'integer', 'table', 'row', 835 ): 836 wrappers = ('"', '"') 837 else: 838 wrappers = ('', '') 839 else: 840 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 841 842 ### NOTE: SQLite does not support schemas. 843 if flavor == 'sqlite': 844 schema = None 845 elif flavor == 'mssql' and str(item).startswith('#'): 846 schema = None 847 848 schema_prefix = ( 849 (wrappers[0] + schema + wrappers[1] + '.') 850 if schema is not None 851 else '' 852 ) 853 854 return schema_prefix + wrappers[0] + truncated_item + wrappers[1] 855 856 857def pg_capital(s: str, quote_capitals: bool = True) -> str: 858 """ 859 If string contains a capital letter, wrap it in double quotes. 860 861 Parameters 862 ---------- 863 s: str 864 The string to be escaped. 
def get_distinct_col_count(
    col: str,
    query: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    debug: bool = False
) -> Optional[int]:
    """
    Return the number of distinct items in a column of a SQL query.

    Parameters
    ----------
    col: str:
        The column in the query to count.

    query: str:
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None:
        The SQLConnector to execute the query.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of distinct values in the column, or `None` if the query fails.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    _col_name = sql_item_name(col, connector.flavor, None)

    _meta_query = (
        f"""
        WITH src AS ( {query} ),
        dist AS ( SELECT DISTINCT {_col_name} FROM src )
        SELECT COUNT(*) FROM dist"""
    ) if connector.flavor not in ('mysql', 'mariadb') else (
        f"""
        SELECT COUNT(*)
        FROM (
            SELECT DISTINCT {_col_name}
            FROM ({query}) AS src
        ) AS dist"""
    )

    result = connector.value(_meta_query, debug=debug)
    try:
        return int(result)
    except Exception:
        return None
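### Hedged usage sketch (the connector keys, table, and result below are
### illustrative assumptions):
###
### >>> conn = mrsm.get_connector('sql:main')
### >>> get_distinct_col_count('species', 'SELECT species FROM plants', connector=conn)
### 3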
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Parse SQL items depending on the flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).

    schema: Optional[str], default None
        If provided, prefix the table name with the schema.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    truncated_item = truncate_item_name(str(item), flavor)
    if flavor == 'oracle':
        truncated_item = pg_capital(truncated_item, quote_capitals=True)
        ### NOTE: System-reserved words must be quoted.
        if truncated_item.lower() in (
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table', 'row',
        ):
            wrappers = ('"', '"')
        else:
            wrappers = ('', '')
    else:
        wrappers = table_wrappers.get(flavor, table_wrappers['default'])

    ### NOTE: SQLite does not support schemas.
    if flavor == 'sqlite':
        schema = None
    elif flavor == 'mssql' and str(item).startswith('#'):
        schema = None

    schema_prefix = (
        (wrappers[0] + schema + wrappers[1] + '.')
        if schema is not None
        else ''
    )

    return schema_prefix + wrappers[0] + truncated_item + wrappers[1]


def pg_capital(s: str, quote_capitals: bool = True) -> str:
    """
    If string contains a capital letter, wrap it in double quotes.

    Parameters
    ----------
    s: str
        The string to be escaped.

    quote_capitals: bool, default True
        If `False`, do not quote strings that contain only a mix of capital and lower-case letters.

    Returns
    -------
    The input string wrapped in quotes only if it needs them.

    Examples
    --------
    >>> pg_capital("My Table")
    '"My Table"'
    >>> pg_capital('my_table')
    'my_table'

    """
    if s.startswith('"') and s.endswith('"'):
        return s

    s = s.replace('"', '')

    needs_quotes = s.startswith('_')
    if not needs_quotes:
        for c in s:
            if c == '_':
                continue

            if not c.isalnum() or (quote_capitals and c.isupper()):
                needs_quotes = True
                break

    if needs_quotes:
        return '"' + s + '"'

    return s


def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.
    """
    return s


def truncate_item_name(item: str, flavor: str) -> str:
    """
    Truncate item names to stay within the database flavor's character limit.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections
    return truncate_string_sections(
        item, max_len=max_name_lens.get(flavor, max_name_lens['default'])
    )
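### Hedged sketch of `sql_item_name`'s Oracle handling above: reserved words
### are quoted, while plain lower-case identifiers pass through unwrapped.
###
### >>> sql_item_name('table', 'oracle')
### '"table"'
### >>> sql_item_name('my_table', 'oracle')
### 'my_table'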
1093 """ 1094 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1095 schema = schema or connector.schema 1096 insp = sqlalchemy.inspect(connector.engine) 1097 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1098 return insp.has_table(truncated_table_name, schema=schema) 1099 1100 1101def get_sqlalchemy_table( 1102 table: str, 1103 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1104 schema: Optional[str] = None, 1105 refresh: bool = False, 1106 debug: bool = False, 1107) -> Union['sqlalchemy.Table', None]: 1108 """ 1109 Construct a SQLAlchemy table from its name. 1110 1111 Parameters 1112 ---------- 1113 table: str 1114 The name of the table on the database. Does not need to be escaped. 1115 1116 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1117 The connector to the database which holds the table. 1118 1119 schema: Optional[str], default None 1120 Specify on which schema the table resides. 1121 Defaults to the schema set in `connector`. 1122 1123 refresh: bool, default False 1124 If `True`, rebuild the cached table object. 1125 1126 debug: bool, default False: 1127 Verbosity toggle. 1128 1129 Returns 1130 ------- 1131 A `sqlalchemy.Table` object for the table. 1132 1133 """ 1134 if connector is None: 1135 from meerschaum import get_connector 1136 connector = get_connector('sql') 1137 1138 if connector.flavor == 'duckdb': 1139 return None 1140 1141 from meerschaum.connectors.sql.tables import get_tables 1142 from meerschaum.utils.packages import attempt_import 1143 from meerschaum.utils.warnings import warn 1144 if refresh: 1145 connector.metadata.clear() 1146 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 1147 sqlalchemy = attempt_import('sqlalchemy', lazy=False) 1148 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1149 table_kwargs = { 1150 'autoload_with': connector.engine, 1151 } 1152 if schema: 1153 table_kwargs['schema'] = schema 1154 1155 if refresh or truncated_table_name not in tables: 1156 try: 1157 tables[truncated_table_name] = sqlalchemy.Table( 1158 truncated_table_name, 1159 connector.metadata, 1160 **table_kwargs 1161 ) 1162 except sqlalchemy.exc.NoSuchTableError: 1163 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 1164 return None 1165 return tables[truncated_table_name] 1166 1167 1168def get_table_cols_types( 1169 table: str, 1170 connectable: Union[ 1171 'mrsm.connectors.sql.SQLConnector', 1172 'sqlalchemy.orm.session.Session', 1173 'sqlalchemy.engine.base.Engine' 1174 ], 1175 flavor: Optional[str] = None, 1176 schema: Optional[str] = None, 1177 database: Optional[str] = None, 1178 debug: bool = False, 1179) -> Dict[str, str]: 1180 """ 1181 Return a dictionary mapping a table's columns to data types. 1182 This is useful for inspecting tables creating during a not-yet-committed session. 1183 1184 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1185 Use this function if you are confident the table name is unique or if you have 1186 and explicit schema. 1187 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1188 1189 Parameters 1190 ---------- 1191 table: str 1192 The name of the table (unquoted). 1193 1194 connectable: Union[ 1195 'mrsm.connectors.sql.SQLConnector', 1196 'sqlalchemy.orm.session.Session', 1197 'sqlalchemy.engine.base.Engine' 1198 ] 1199 The connection object used to fetch the columns and types. 
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """Check if a table exists.

    Parameters
    ----------
    table: str:
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    schema = schema or connector.schema
    insp = sqlalchemy.inspect(connector.engine)
    truncated_table_name = truncate_item_name(str(table), connector.flavor)
    return insp.has_table(truncated_table_name, schema=schema)
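### Hedged usage sketch ('sql:main' and the table name are assumptions):
###
### >>> conn = mrsm.get_connector('sql:main')
### >>> table_exists('plants', conn, schema='public')
### False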
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None:
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn
    if refresh:
        connector.metadata.clear()
    tables = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy', lazy=False)
    truncated_table_name = truncate_item_name(str(table), connector.flavor)
    table_kwargs = {
        'autoload_with': connector.engine,
    }
    if schema:
        table_kwargs['schema'] = schema

    if refresh or truncated_table_name not in tables:
        try:
            tables[truncated_table_name] = sqlalchemy.Table(
                truncated_table_name,
                connector.metadata,
                **table_kwargs
            )
        except sqlalchemy.exc.NoSuchTableError:
            warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.")
            return None
    return tables[truncated_table_name]
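### Hedged usage sketch: reflect a table once and reuse the cached object,
### passing `refresh=True` after DDL changes (table and column names below
### are illustrative assumptions):
###
### >>> tbl = get_sqlalchemy_table('plants', connector=mrsm.get_connector('sql:main'))
### >>> [c.name for c in tbl.columns]
### ['species', 'dt', 'value']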
1392 """ 1393 import textwrap 1394 from collections import defaultdict 1395 from meerschaum.connectors import SQLConnector 1396 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1397 flavor = flavor or getattr(connectable, 'flavor', None) 1398 if not flavor: 1399 raise ValueError("Please provide a database flavor.") 1400 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1401 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1402 if flavor in NO_SCHEMA_FLAVORS: 1403 schema = None 1404 if schema is None: 1405 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1406 if flavor in ('sqlite', 'duckdb', 'oracle'): 1407 database = None 1408 table_trunc = truncate_item_name(table, flavor=flavor) 1409 table_lower = table.lower() 1410 table_upper = table.upper() 1411 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1412 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1413 db_prefix = ( 1414 "tempdb." 1415 if flavor == 'mssql' and table.startswith('#') 1416 else "" 1417 ) 1418 1419 cols_indices_query = sqlalchemy.text( 1420 textwrap.dedent(columns_indices_queries.get( 1421 flavor, 1422 columns_indices_queries['default'] 1423 ).format( 1424 table=table, 1425 table_trunc=table_trunc, 1426 table_lower=table_lower, 1427 table_lower_trunc=table_lower_trunc, 1428 table_upper=table_upper, 1429 table_upper_trunc=table_upper_trunc, 1430 db_prefix=db_prefix, 1431 schema=schema, 1432 )).lstrip().rstrip() 1433 ) 1434 1435 cols = ['database', 'schema', 'table', 'column', 'index', 'index_type'] 1436 if flavor == 'mssql': 1437 cols.append('clustered') 1438 result_cols_ix = dict(enumerate(cols)) 1439 1440 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1441 if not debug_kwargs and debug: 1442 dprint(cols_indices_query) 1443 1444 try: 1445 result_rows = ( 1446 [ 1447 row 1448 for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall() 1449 ] 1450 if flavor != 'duckdb' 1451 else [ 1452 tuple([doc[col] for col in cols]) 1453 for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records') 1454 ] 1455 ) 1456 cols_types_docs = [ 1457 { 1458 result_cols_ix[i]: val 1459 for i, val in enumerate(row) 1460 } 1461 for row in result_rows 1462 ] 1463 cols_types_docs_filtered = [ 1464 doc 1465 for doc in cols_types_docs 1466 if ( 1467 ( 1468 not schema 1469 or doc['schema'] == schema 1470 ) 1471 and 1472 ( 1473 not database 1474 or doc['database'] == database 1475 ) 1476 ) 1477 ] 1478 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[str]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to a list of indices.
    """
    import textwrap
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        textwrap.dedent(columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )).lstrip().rstrip()
    )

    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    if flavor == 'mssql':
        cols.append('clustered')
    result_cols_ix = dict(enumerate(cols))

    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]
        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        cols_indices = defaultdict(lambda: [])
        for doc in cols_types_docs_filtered:
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            index_doc = {
                'name': doc.get('index', None),
                'type': doc.get('index_type', None)
            }
            if flavor == 'mssql':
                index_doc['clustered'] = doc.get('clustered', None)
            cols_indices[col].append(index_doc)

        return dict(cols_indices)
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
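### Hedged usage sketch (a hypothetical table with a primary key 'id'; the
### index name is illustrative):
###
### >>> get_table_cols_indices('plants', mrsm.get_connector('sql:main'))
### {'id': [{'name': 'pk_plants', 'type': 'PRIMARY KEY'}]}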
1576 """ 1577 import textwrap 1578 from meerschaum.connectors import SQLConnector 1579 from meerschaum.utils.debug import dprint 1580 from meerschaum.utils.dtypes import are_dtypes_equal 1581 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type 1582 flavor = flavor or getattr(connectable, 'flavor', None) 1583 if not flavor: 1584 raise ValueError("Provide a flavor if using a SQLAlchemy session.") 1585 if ( 1586 flavor == 'sqlite' 1587 and isinstance(connectable, SQLConnector) 1588 and connectable.db_version < '3.33.0' 1589 ): 1590 flavor = 'sqlite_delete_insert' 1591 flavor_key = (f'{flavor}-upsert' if upsert else flavor) 1592 base_queries = UPDATE_QUERIES.get( 1593 flavor_key, 1594 UPDATE_QUERIES['default'] 1595 ) 1596 if not isinstance(base_queries, list): 1597 base_queries = [base_queries] 1598 schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None) 1599 patch_schema = patch_schema or schema 1600 target_table_columns = get_table_cols_types( 1601 target, 1602 connectable, 1603 flavor=flavor, 1604 schema=schema, 1605 debug=debug, 1606 ) 1607 patch_table_columns = get_table_cols_types( 1608 patch, 1609 connectable, 1610 flavor=flavor, 1611 schema=patch_schema, 1612 debug=debug, 1613 ) 1614 1615 patch_cols_str = ', '.join( 1616 [ 1617 sql_item_name(col, flavor) 1618 for col in patch_table_columns 1619 ] 1620 ) 1621 patch_cols_prefixed_str = ', '.join( 1622 [ 1623 'p.' + sql_item_name(col, flavor) 1624 for col in patch_table_columns 1625 ] 1626 ) 1627 1628 join_cols_str = ', '.join( 1629 [ 1630 sql_item_name(col, flavor) 1631 for col in join_cols 1632 ] 1633 ) 1634 1635 value_cols = [] 1636 join_cols_types = [] 1637 if debug: 1638 dprint("target_table_columns:") 1639 mrsm.pprint(target_table_columns) 1640 for c_name, c_type in target_table_columns.items(): 1641 if c_name not in patch_table_columns: 1642 continue 1643 if flavor in DB_FLAVORS_CAST_DTYPES: 1644 c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type) 1645 ( 1646 join_cols_types 1647 if c_name in join_cols 1648 else value_cols 1649 ).append((c_name, c_type)) 1650 if debug: 1651 dprint(f"value_cols: {value_cols}") 1652 1653 if not join_cols_types: 1654 return [] 1655 if not value_cols and not upsert: 1656 return [] 1657 1658 coalesce_join_cols_str = ', '.join( 1659 [ 1660 ( 1661 ( 1662 'COALESCE(' 1663 + sql_item_name(c_name, flavor) 1664 + ', ' 1665 + get_null_replacement(c_type, flavor) 1666 + ')' 1667 ) 1668 if null_indices 1669 else sql_item_name(c_name, flavor) 1670 ) 1671 for c_name, c_type in join_cols_types 1672 ] 1673 ) 1674 1675 update_or_nothing = ('UPDATE' if value_cols else 'NOTHING') 1676 1677 def sets_subquery(l_prefix: str, r_prefix: str): 1678 if not value_cols: 1679 return '' 1680 1681 utc_value_cols = { 1682 c_name 1683 for c_name, c_type in value_cols 1684 if ('utc' in get_pd_type_from_db_type(c_type).lower()) 1685 } if flavor not in TIMEZONE_NAIVE_FLAVORS else set() 1686 1687 cast_func_cols = { 1688 c_name: ( 1689 ('', '', '') 1690 if not cast_columns or ( 1691 flavor == 'oracle' 1692 and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes') 1693 ) 1694 else ( 1695 ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + ( 1696 " AT TIME ZONE 'UTC'" 1697 if c_name in utc_value_cols 1698 else '' 1699 )) 1700 if flavor != 'sqlite' 1701 else ('', '', '') 1702 ) 1703 ) 1704 for c_name, c_type in value_cols 1705 } 1706 return 'SET ' + ',\n'.join([ 1707 ( 1708 l_prefix + sql_item_name(c_name, flavor, None) 1709 + ' = ' 1710 + 
cast_func_cols[c_name][0] 1711 + r_prefix + sql_item_name(c_name, flavor, None) 1712 + cast_func_cols[c_name][1] 1713 + cast_func_cols[c_name][2] 1714 ) for c_name, c_type in value_cols 1715 ]) 1716 1717 def and_subquery(l_prefix: str, r_prefix: str): 1718 return '\n AND\n '.join([ 1719 ( 1720 ( 1721 "COALESCE(" 1722 + l_prefix 1723 + sql_item_name(c_name, flavor, None) 1724 + ", " 1725 + get_null_replacement(c_type, flavor) 1726 + ")" 1727 + '\n =\n ' 1728 + "COALESCE(" 1729 + r_prefix 1730 + sql_item_name(c_name, flavor, None) 1731 + ", " 1732 + get_null_replacement(c_type, flavor) 1733 + ")" 1734 ) 1735 if null_indices 1736 else ( 1737 l_prefix 1738 + sql_item_name(c_name, flavor, None) 1739 + ' = ' 1740 + r_prefix 1741 + sql_item_name(c_name, flavor, None) 1742 ) 1743 ) for c_name, c_type in join_cols_types 1744 ]) 1745 1746 skip_query_val = "" 1747 target_table_name = sql_item_name(target, flavor, schema) 1748 patch_table_name = sql_item_name(patch, flavor, patch_schema) 1749 dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None 1750 date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]' 1751 min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]' 1752 max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]' 1753 date_bounds_subquery = ( 1754 f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table}) 1755 AND 1756 f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})""" 1757 if datetime_col 1758 else "1 = 1" 1759 ) 1760 with_temp_date_bounds = f"""WITH [date_bounds] AS ( 1761 SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name} 1762 FROM {patch_table_name} 1763 )""" if datetime_col else "" 1764 identity_insert_on = ( 1765 f"SET IDENTITY_INSERT {target_table_name} ON" 1766 if identity_insert 1767 else skip_query_val 1768 ) 1769 identity_insert_off = ( 1770 f"SET IDENTITY_INSERT {target_table_name} OFF" 1771 if identity_insert 1772 else skip_query_val 1773 ) 1774 1775 ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices. 
1776 when_matched_update_sets_subquery_none = "" if not value_cols else ( 1777 "\n WHEN MATCHED THEN\n" 1778 f" UPDATE {sets_subquery('', 'p.')}" 1779 ) 1780 1781 cols_equal_values = '\n,'.join( 1782 [ 1783 f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})" 1784 for c_name, c_type in value_cols 1785 ] 1786 ) 1787 on_duplicate_key_update = ( 1788 "ON DUPLICATE KEY UPDATE" 1789 if value_cols 1790 else "" 1791 ) 1792 ignore = "IGNORE " if not value_cols else "" 1793 1794 formatted_queries = [ 1795 textwrap.dedent(base_query.format( 1796 sets_subquery_none=sets_subquery('', 'p.'), 1797 sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'), 1798 sets_subquery_f=sets_subquery('f.', 'p.'), 1799 and_subquery_f=and_subquery('p.', 'f.'), 1800 and_subquery_t=and_subquery('p.', 't.'), 1801 target_table_name=target_table_name, 1802 patch_table_name=patch_table_name, 1803 patch_cols_str=patch_cols_str, 1804 patch_cols_prefixed_str=patch_cols_prefixed_str, 1805 date_bounds_subquery=date_bounds_subquery, 1806 join_cols_str=join_cols_str, 1807 coalesce_join_cols_str=coalesce_join_cols_str, 1808 update_or_nothing=update_or_nothing, 1809 when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none, 1810 cols_equal_values=cols_equal_values, 1811 on_duplicate_key_update=on_duplicate_key_update, 1812 ignore=ignore, 1813 with_temp_date_bounds=with_temp_date_bounds, 1814 identity_insert_on=identity_insert_on, 1815 identity_insert_off=identity_insert_off, 1816 )).lstrip().rstrip() 1817 for base_query in base_queries 1818 ] 1819 1820 ### NOTE: Allow for skipping some queries. 1821 return [query for query in formatted_queries if query] 1822 1823 1824def get_null_replacement(typ: str, flavor: str) -> str: 1825 """ 1826 Return a value that may temporarily be used in place of NULL for this type. 1827 1828 Parameters 1829 ---------- 1830 typ: str 1831 The typ to be converted to NULL. 1832 1833 flavor: str 1834 The database flavor for which this value will be used. 1835 1836 Returns 1837 ------- 1838 A value which may stand in place of NULL for this type. 1839 `'None'` is returned if a value cannot be determined. 
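### Hedged usage sketch: stage a patch table with the same shape as the
### target, then build and run the flavor-specific queries (the connector
### keys and table names are illustrative assumptions):
###
### >>> conn = mrsm.get_connector('sql:main')
### >>> queries = get_update_queries(
### ...     'plants',
### ...     'plants_patch',
### ...     conn,
### ...     join_cols=['species'],
### ...     upsert=True,
### ... )
### >>> for query in queries:
### ...     conn.exec(query)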
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The database type for which to produce a stand-in value.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A value which may stand in place of NULL for this type.
    `'None'` is returned if a value cannot be determined.
    """
    from meerschaum.utils.dtypes import are_dtypes_equal
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    if 'geometry' in typ.lower():
        return '010100000000008058346FCDC100008058346FCDC1'
    if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'):
        return '-987654321'
    if 'bool' in typ.lower() or typ.lower() == 'bit':
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'
    if 'time' in typ.lower() or 'date' in typ.lower():
        db_type = typ if typ.isupper() else None
        return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type)
    if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',):
        return '-987654321.0'
    if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"
    if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'):
        return '00'
    if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"


def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.
    """
    version_name = sql_item_name('version', conn.flavor, None)
    version_query = version_queries.get(
        conn.flavor,
        version_queries['default']
    ).format(version_name=version_name)
    return conn.value(version_query, debug=debug)
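### Hedged sketch of `get_null_replacement` above (the integer and UUID
### branches are deterministic in the code shown):
###
### >>> get_null_replacement('INT', 'postgresql')
### '-987654321'
### >>> get_null_replacement('UUID', 'postgresql')
### "'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'"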
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    old_table_name = sql_item_name(old_table, flavor, schema)
    new_table_name = sql_item_name(new_table, flavor, None)
    tmp_table = '_tmp_rename_' + new_table
    tmp_table_name = sql_item_name(tmp_table, flavor, schema)
    if flavor == 'mssql':
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""
    if flavor == 'duckdb':
        return (
            get_create_table_queries(
                f"SELECT * FROM {old_table_name}",
                tmp_table,
                'duckdb',
                schema,
            ) + get_create_table_queries(
                f"SELECT * FROM {tmp_table_name}",
                new_table,
                'duckdb',
                schema,
            ) + [
                f"DROP TABLE {if_exists_str} {tmp_table_name}",
                f"DROP TABLE {if_exists_str} {old_table_name}",
            ]
        )

    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]


def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    return get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )[0]


def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    primary_key_db_type: Optional[str], default None
        If provided, alter the primary key to this type (to set NOT NULL constraint).

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    if not isinstance(query_or_dtypes, (str, dict)):
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    method = (
        _get_create_table_query_from_cte
        if isinstance(query_or_dtypes, str)
        else _get_create_table_query_from_dtypes
    )
    return method(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        primary_key_db_type=primary_key_db_type,
        autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS),
        datetime_column=datetime_column,
    )
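### Hedged usage sketches for the DDL helpers above. The rename branches are
### deterministic in the code shown:
###
### >>> get_rename_table_queries('old_table', 'new_table', 'postgresql')
### ['ALTER TABLE "old_table" RENAME TO "new_table"']
### >>> get_rename_table_queries('old_table', 'new_table', 'mssql')
### ["EXEC sp_rename 'old_table', 'new_table'"]
###
### `get_create_table_queries()` accepts either a `SELECT` query or a dtypes
### dictionary; the exact DDL it returns is flavor-dependent, e.g.:
###
### >>> get_create_table_queries(
### ...     {'id': 'int', 'value': 'float'},
### ...     'new_table',
### ...     'postgresql',
### ...     primary_key='id',
### ... )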
2026 2027 Returns 2028 ------- 2029 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2030 """ 2031 if not isinstance(query_or_dtypes, (str, dict)): 2032 raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.") 2033 2034 method = ( 2035 _get_create_table_query_from_cte 2036 if isinstance(query_or_dtypes, str) 2037 else _get_create_table_query_from_dtypes 2038 ) 2039 return method( 2040 query_or_dtypes, 2041 new_table, 2042 flavor, 2043 schema=schema, 2044 primary_key=primary_key, 2045 primary_key_db_type=primary_key_db_type, 2046 autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS), 2047 datetime_column=datetime_column, 2048 ) 2049 2050 2051def _get_create_table_query_from_dtypes( 2052 dtypes: Dict[str, str], 2053 new_table: str, 2054 flavor: str, 2055 schema: Optional[str] = None, 2056 primary_key: Optional[str] = None, 2057 primary_key_db_type: Optional[str] = None, 2058 autoincrement: bool = False, 2059 datetime_column: Optional[str] = None, 2060) -> List[str]: 2061 """ 2062 Create a new table from a `dtypes` dictionary. 2063 """ 2064 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, AUTO_INCREMENT_COLUMN_FLAVORS 2065 if not dtypes and not primary_key: 2066 raise ValueError(f"Expecting columns for table '{new_table}'.") 2067 2068 if flavor in SKIP_AUTO_INCREMENT_FLAVORS: 2069 autoincrement = False 2070 2071 cols_types = ( 2072 [ 2073 ( 2074 primary_key, 2075 get_db_type_from_pd_type(dtypes.get(primary_key, 'int') or 'int', flavor=flavor) 2076 ) 2077 ] 2078 if primary_key 2079 else [] 2080 ) + [ 2081 (col, get_db_type_from_pd_type(typ, flavor=flavor)) 2082 for col, typ in dtypes.items() 2083 if col != primary_key 2084 ] 2085 2086 table_name = sql_item_name(new_table, schema=schema, flavor=flavor) 2087 primary_key_name = sql_item_name(primary_key, flavor) if primary_key else None 2088 primary_key_constraint_name = ( 2089 sql_item_name(f'PK_{new_table}', flavor, None) 2090 if primary_key 2091 else None 2092 ) 2093 datetime_column_name = sql_item_name(datetime_column, flavor) if datetime_column else None 2094 primary_key_clustered = ( 2095 "CLUSTERED" 2096 if not datetime_column or datetime_column == primary_key 2097 else "NONCLUSTERED" 2098 ) 2099 query = f"CREATE TABLE {table_name} (" 2100 if primary_key: 2101 col_db_type = cols_types[0][1] 2102 auto_increment_str = (' ' + AUTO_INCREMENT_COLUMN_FLAVORS.get( 2103 flavor, 2104 AUTO_INCREMENT_COLUMN_FLAVORS['default'] 2105 )) if autoincrement or primary_key not in dtypes else '' 2106 col_name = sql_item_name(primary_key, flavor=flavor, schema=None) 2107 2108 if flavor == 'sqlite': 2109 query += ( 2110 f"\n {col_name} " 2111 + (f"{col_db_type}" if not auto_increment_str else 'INTEGER') 2112 + f" PRIMARY KEY{auto_increment_str} NOT NULL," 2113 ) 2114 elif flavor == 'oracle': 2115 query += f"\n {col_name} {col_db_type} {auto_increment_str} PRIMARY KEY," 2116 elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key: 2117 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2118 elif flavor == 'mssql': 2119 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2120 else: 2121 query += f"\n {col_name} {col_db_type} PRIMARY KEY{auto_increment_str} NOT NULL," 2122 2123 for col, db_type in cols_types: 2124 if col == primary_key: 2125 continue 2126 col_name = sql_item_name(col, schema=None, flavor=flavor) 2127 query += f"\n {col_name} {db_type}," 2128 if ( 2129 flavor == 'timescaledb' 2130 and datetime_column 2131 and 
primary_key 2132 and datetime_column != primary_key 2133 ): 2134 query += f"\n    PRIMARY KEY({datetime_column_name}, {primary_key_name})," 2135 2136 if flavor == 'mssql' and primary_key: 2137 query += f"\n    CONSTRAINT {primary_key_constraint_name} PRIMARY KEY {primary_key_clustered} ({primary_key_name})," 2138 2139 query = query[:-1] 2140 query += "\n)" 2141 2142 queries = [query] 2143 return queries 2144 2145 2146def _get_create_table_query_from_cte( 2147 query: str, 2148 new_table: str, 2149 flavor: str, 2150 schema: Optional[str] = None, 2151 primary_key: Optional[str] = None, 2152 primary_key_db_type: Optional[str] = None, 2153 autoincrement: bool = False, 2154 datetime_column: Optional[str] = None, 2155) -> List[str]: 2156 """ 2157 Create a new table from a CTE query. 2158 """ 2159 import textwrap 2160 create_cte = 'create_query' 2161 create_cte_name = sql_item_name(create_cte, flavor, None) 2162 new_table_name = sql_item_name(new_table, flavor, schema) 2163 primary_key_constraint_name = ( 2164 sql_item_name(f'PK_{new_table}', flavor, None) 2165 if primary_key 2166 else None 2167 ) 2168 primary_key_name = ( 2169 sql_item_name(primary_key, flavor, None) 2170 if primary_key 2171 else None 2172 ) 2173 primary_key_clustered = ( 2174 "CLUSTERED" 2175 if not datetime_column or datetime_column == primary_key 2176 else "NONCLUSTERED" 2177 ) 2178 datetime_column_name = ( 2179 sql_item_name(datetime_column, flavor) 2180 if datetime_column 2181 else None 2182 ) 2183 if flavor in ('mssql',): 2184 query = query.lstrip() 2185 if query.lower().startswith('with '): 2186 final_select_ix = query.lower().rfind('select') 2187 create_table_queries = [ 2188 ( 2189 query[:final_select_ix].rstrip() + ',\n' 2190 + f"{create_cte_name} AS (\n" 2191 + textwrap.indent(query[final_select_ix:], ' ') 2192 + "\n)\n" 2193 + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}" 2194 ), 2195 ] 2196 else: 2197 create_table_queries = [ 2198 ( 2199 "SELECT *\n" 2200 f"INTO {new_table_name}\n" 2201 f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}" 2202 ), 2203 ] 2204 2205 alter_type_queries = [] 2206 if primary_key_db_type: 2207 alter_type_queries.extend([ 2208 ( 2209 f"ALTER TABLE {new_table_name}\n" 2210 f"ALTER COLUMN {primary_key_name} {primary_key_db_type} NOT NULL" 2211 ), 2212 ]) 2213 alter_type_queries.extend([ 2214 ( 2215 f"ALTER TABLE {new_table_name}\n" 2216 f"ADD CONSTRAINT {primary_key_constraint_name} " 2217 f"PRIMARY KEY {primary_key_clustered} ({primary_key_name})" 2218 ), 2219 ]) 2220 elif flavor in (None,): 2221 create_table_queries = [ 2222 ( 2223 f"WITH {create_cte_name} AS (\n{textwrap.indent(query, ' ')}\n)\n" 2224 f"CREATE TABLE {new_table_name} AS\n" 2225 "SELECT *\n" 2226 f"FROM {create_cte_name}" 2227 ), 2228 ] 2229 2230 alter_type_queries = [ 2231 ( 2232 f"ALTER TABLE {new_table_name}\n" 2233 f"ADD PRIMARY KEY ({primary_key_name})" 2234 ), 2235 ] 2236 elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'): 2237 create_table_queries = [ 2238 ( 2239 f"CREATE TABLE {new_table_name} AS\n" 2240 "SELECT *\n" 2241 f"FROM (\n{textwrap.indent(query, ' ')}\n)" 2242 + (f" AS {create_cte_name}" if flavor != 'oracle' else '') 2243 ), 2244 ] 2245 2246 alter_type_queries = [ 2247 ( 2248 f"ALTER TABLE {new_table_name}\n" 2249 f"ADD PRIMARY KEY ({primary_key_name})" 2250 ), 2251 ] 2252 elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key: 2253 create_table_queries = [ 2254 ( 2255 "SELECT *\n" 2256 f"INTO {new_table_name}\n" 2257 f"FROM 
(\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}\n" 2258 ), 2259 ] 2260 2261 alter_type_queries = [ 2262 ( 2263 f"ALTER TABLE {new_table_name}\n" 2264 f"ADD PRIMARY KEY ({datetime_column_name}, {primary_key_name})" 2265 ), 2266 ] 2267 else: 2268 create_table_queries = [ 2269 ( 2270 "SELECT *\n" 2271 f"INTO {new_table_name}\n" 2272 f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}" 2273 ), 2274 ] 2275 2276 alter_type_queries = [ 2277 ( 2278 f"ALTER TABLE {new_table_name}\n" 2279 f"ADD PRIMARY KEY ({primary_key_name})" 2280 ), 2281 ] 2282 2283 if not primary_key: 2284 return create_table_queries 2285 2286 return create_table_queries + alter_type_queries 2287 2288 2289def wrap_query_with_cte( 2290 sub_query: str, 2291 parent_query: str, 2292 flavor: str, 2293 cte_name: str = "src", 2294) -> str: 2295 """ 2296 Wrap a subquery in a CTE and append an encapsulating query. 2297 2298 Parameters 2299 ---------- 2300 sub_query: str 2301 The query to be referenced. This may itself contain CTEs. 2302 Unless `cte_name` is provided, this will be aliased as `src`. 2303 2304 parent_query: str 2305 The larger query to append which references the subquery. 2306 This must not contain CTEs. 2307 2308 flavor: str 2309 The database flavor, e.g. `'mssql'`. 2310 2311 cte_name: str, default 'src' 2312 The CTE alias, defaults to `src`. 2313 2314 Returns 2315 ------- 2316 An encapsulating query which allows you to treat `sub_query` as a temporary table. 2317 2318 Examples 2319 -------- 2320 2321 ```python 2322 from meerschaum.utils.sql import wrap_query_with_cte 2323 sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo" 2324 parent_query = "SELECT newval * 3 FROM src" 2325 query = wrap_query_with_cte(sub_query, parent_query, 'mssql') 2326 print(query) 2327 # WITH foo AS (SELECT 1 AS val), 2328 # [src] AS ( 2329 # SELECT (val * 2) AS newval FROM foo 2330 # ) 2331 # SELECT newval * 3 FROM src 2332 ``` 2333 2334 """ 2335 import textwrap 2336 sub_query = sub_query.lstrip() 2337 cte_name_quoted = sql_item_name(cte_name, flavor, None) 2338 2339 if flavor in NO_CTE_FLAVORS: 2340 return ( 2341 parent_query 2342 .replace(cte_name_quoted, '--MRSM_SUBQUERY--') 2343 .replace(cte_name, '--MRSM_SUBQUERY--') 2344 .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}") 2345 ) 2346 2347 if sub_query.lstrip().lower().startswith('with '): 2348 final_select_ix = sub_query.lower().rfind('select') 2349 return ( 2350 sub_query[:final_select_ix].rstrip() + ',\n' 2351 + f"{cte_name_quoted} AS (\n" 2352 + ' ' + sub_query[final_select_ix:] 2353 + "\n)\n" 2354 + parent_query 2355 ) 2356 2357 return ( 2358 f"WITH {cte_name_quoted} AS (\n" 2359 f"{textwrap.indent(sub_query, ' ')}\n" 2360 f")\n{parent_query}" 2361 ) 2362 2363 2364def format_cte_subquery( 2365 sub_query: str, 2366 flavor: str, 2367 sub_name: str = 'src', 2368 cols_to_select: Union[List[str], str] = '*', 2369) -> str: 2370 """ 2371 Given a subquery, build a wrapper query that selects from the CTE subquery. 2372 2373 Parameters 2374 ---------- 2375 sub_query: str 2376 The subquery to wrap. 2377 2378 flavor: str 2379 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2380 2381 sub_name: str, default 'src' 2382 If possible, give this name to the CTE (must be unquoted). 2383 2384 cols_to_select: Union[List[str], str], default '*' 2385 If specified, choose which columns to select from the CTE. 2386 If a list of strings is provided, each item will be quoted and joined with commas.
2387 If a string is given, assume it is quoted and insert it into the query. 2388 2389 Returns 2390 ------- 2391 A wrapper query that selects from the CTE. 2392 """ 2393 quoted_sub_name = sql_item_name(sub_name, flavor, None) 2394 cols_str = ( 2395 cols_to_select 2396 if isinstance(cols_to_select, str) 2397 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 2398 ) 2399 parent_query = ( 2400 f"SELECT {cols_str}\n" 2401 f"FROM {quoted_sub_name}" 2402 ) 2403 return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name) 2404 2405 2406def session_execute( 2407 session: 'sqlalchemy.orm.session.Session', 2408 queries: Union[List[str], str], 2409 with_results: bool = False, 2410 debug: bool = False, 2411) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 2412 """ 2413 Similar to `SQLConnector.exec_queries()`, execute a list of queries 2414 and roll back when one fails. 2415 2416 Parameters 2417 ---------- 2418 session: sqlalchemy.orm.session.Session 2419 A SQLAlchemy session representing a transaction. 2420 2421 queries: Union[List[str], str] 2422 A query or list of queries to be executed. 2423 If a query fails, roll back the session. 2424 2425 with_results: bool, default False 2426 If `True`, return a list of result objects. 2427 2428 Returns 2429 ------- 2430 A `SuccessTuple` indicating the queries were successfully executed. 2431 If `with_results`, return the `SuccessTuple` and a list of results. 2432 """ 2433 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2434 if not isinstance(queries, list): 2435 queries = [queries] 2436 successes, msgs, results = [], [], [] 2437 for query in queries: 2438 if debug: 2439 dprint(query) 2440 query_text = sqlalchemy.text(query) 2441 fail_msg = "Failed to execute queries." 2442 try: 2443 result = session.execute(query_text) 2444 query_success = result is not None 2445 query_msg = "Success" if query_success else fail_msg 2446 except Exception as e: 2447 query_success = False 2448 query_msg = f"{fail_msg}\n{e}" 2449 result = None 2450 successes.append(query_success) 2451 msgs.append(query_msg) 2452 results.append(result) 2453 if not query_success: 2454 if debug: 2455 dprint("Rolling back session.") 2456 session.rollback() 2457 break 2458 success, msg = all(successes), '\n'.join(msgs) 2459 if with_results: 2460 return (success, msg), results 2461 return success, msg 2462 2463 2464def get_reset_autoincrement_queries( 2465 table: str, 2466 column: str, 2467 connector: mrsm.connectors.SQLConnector, 2468 schema: Optional[str] = None, 2469 debug: bool = False, 2470) -> List[str]: 2471 """ 2472 Return a list of queries to reset a table's auto-increment counter to the next largest value. 2473 2474 Parameters 2475 ---------- 2476 table: str 2477 The name of the table on which the auto-incrementing column exists. 2478 2479 column: str 2480 The name of the auto-incrementing column. 2481 2482 connector: mrsm.connectors.SQLConnector 2483 The SQLConnector to the database on which the table exists. 2484 2485 schema: Optional[str], default None 2486 The schema of the table. Defaults to `connector.schema`. 2487 2488 Returns 2489 ------- 2490 A list of queries to be executed to reset the auto-incrementing column. 
2491 """ 2492 if not table_exists(table, connector, schema=schema, debug=debug): 2493 return [] 2494 2495 schema = schema or connector.schema 2496 max_id_name = sql_item_name('max_id', connector.flavor) 2497 table_name = sql_item_name(table, connector.flavor, schema) 2498 table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema) 2499 column_name = sql_item_name(column, connector.flavor) 2500 max_id = connector.value( 2501 f""" 2502 SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name} 2503 FROM {table_name} 2504 """, 2505 debug=debug, 2506 ) 2507 if max_id is None: 2508 return [] 2509 2510 reset_queries = reset_autoincrement_queries.get( 2511 connector.flavor, 2512 reset_autoincrement_queries['default'] 2513 ) 2514 if not isinstance(reset_queries, list): 2515 reset_queries = [reset_queries] 2516 2517 return [ 2518 query.format( 2519 column=column, 2520 column_name=column_name, 2521 table=table, 2522 table_name=table_name, 2523 table_seq_name=table_seq_name, 2524 val=max_id, 2525 val_plus_1=(max_id + 1), 2526 ) 2527 for query in reset_queries 2528 ] 2529 2530 2531def get_postgis_geo_columns_types( 2532 connectable: Union[ 2533 'mrsm.connectors.sql.SQLConnector', 2534 'sqlalchemy.orm.session.Session', 2535 'sqlalchemy.engine.base.Engine' 2536 ], 2537 table: str, 2538 schema: Optional[str] = 'public', 2539 debug: bool = False, 2540) -> Dict[str, str]: 2541 """ 2542 Return a dictionary mapping PostGIS geometry column names to geometry types. 2543 """ 2544 from meerschaum.utils.dtypes import get_geometry_type_srid 2545 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2546 default_type, default_srid = get_geometry_type_srid() 2547 default_type = default_type.upper() 2548 2549 clean(table) 2550 clean(str(schema)) 2551 schema = schema or 'public' 2552 truncated_schema_name = truncate_item_name(schema, flavor='postgis') 2553 truncated_table_name = truncate_item_name(table, flavor='postgis') 2554 query = sqlalchemy.text( 2555 "SELECT \"f_geometry_column\" AS \"column\", 'GEOMETRY' AS \"func\", \"type\", \"srid\"\n" 2556 "FROM \"geometry_columns\"\n" 2557 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2558 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2559 "UNION ALL\n" 2560 "SELECT \"f_geography_column\" AS \"column\", 'GEOGRAPHY' AS \"func\", \"type\", \"srid\"\n" 2561 "FROM \"geography_columns\"\n" 2562 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2563 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2564 ) 2565 debug_kwargs = {'debug': debug} if isinstance(connectable, mrsm.connectors.SQLConnector) else {} 2566 result_rows = [ 2567 row 2568 for row in connectable.execute(query, **debug_kwargs).fetchall() 2569 ] 2570 cols_type_tuples = { 2571 row[0]: (row[1], row[2], row[3]) 2572 for row in result_rows 2573 } 2574 2575 geometry_cols_types = { 2576 col: ( 2577 f"{func}({typ.upper()}, {srid})" 2578 if srid 2579 else ( 2580 func 2581 + ( 2582 f'({typ.upper()})' 2583 if typ.upper() not in ('GEOMETRY', 'GEOGRAPHY') 2584 else '' 2585 ) 2586 ) 2587 ) 2588 for col, (func, typ, srid) in cols_type_tuples.items() 2589 } 2590 return geometry_cols_types 2591 2592 2593def get_create_schema_if_not_exists_queries( 2594 schema: str, 2595 flavor: str, 2596) -> List[str]: 2597 """ 2598 Return the queries to create a schema if it does not yet exist. 2599 For databases which do not support schemas, an empty list will be returned. 
2600 """ 2601 if not schema: 2602 return [] 2603 2604 if flavor in NO_SCHEMA_FLAVORS: 2605 return [] 2606 2607 if schema == DEFAULT_SCHEMA_FLAVORS.get(flavor, None): 2608 return [] 2609 2610 clean(schema) 2611 2612 if flavor == 'mssql': 2613 return [ 2614 ( 2615 f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema}')\n" 2616 "BEGIN\n" 2617 f" EXEC('CREATE SCHEMA {sql_item_name(schema, flavor)}');\n" 2618 "END;" 2619 ) 2620 ] 2621 2622 if flavor == 'oracle': 2623 return [] 2624 2625 return [ 2626 f"CREATE SCHEMA IF NOT EXISTS {sql_item_name(schema, flavor)};" 2627 ]
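To see how these builders compose, here is a minimal sketch; the `sales` schema, `weather` table, and pandas dtype strings are illustrative, not part of the module:

```python
from meerschaum.utils.sql import (
    get_create_schema_if_not_exists_queries,
    get_create_table_queries,
)

### Build the DDL for a schema (no-op for flavors without schemas)
### and a table with a dedicated primary-key column.
queries = get_create_schema_if_not_exists_queries('sales', 'postgresql')
queries += get_create_table_queries(
    {'station_id': 'int64', 'ts': 'datetime64[ns, UTC]', 'val': 'float64'},
    'weather',
    'postgresql',
    schema='sales',
    primary_key='station_id',
)
for query in queries:
    print(query)
```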
532def clean(substring: str) -> None: 533 """ 534 Ensure a substring is clean enough to be inserted into a SQL query. 535 Raises an exception when banned words are used. 536 """ 537 from meerschaum.utils.warnings import error 538 banned_symbols = [';', '--', 'drop ',] 539 for symbol in banned_symbols: 540 if symbol in str(substring).lower(): 541 error(f"Invalid string: '{substring}'")
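The check is all-or-nothing: a safe substring passes silently, while a banned symbol raises via `error()`. A quick sketch:

```python
from meerschaum.utils.sql import clean

clean("temperature")          ### Passes silently (returns None).
clean("1; DROP TABLE users")  ### Raises: contains ';' and 'drop '.
```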
544def dateadd_str( 545 flavor: str = 'postgresql', 546 datepart: str = 'day', 547 number: Union[int, float] = 0, 548 begin: Union[str, datetime, int] = 'now', 549 db_type: Optional[str] = None, 550) -> str: 551 """ 552 Generate a `DATEADD` clause depending on database flavor. 553 554 Parameters 555 ---------- 556 flavor: str, default `'postgresql'` 557 SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. 558 559 Currently supported flavors: 560 561 - `'postgresql'` 562 - `'postgis'` 563 - `'timescaledb'` 564 - `'citus'` 565 - `'cockroachdb'` 566 - `'duckdb'` 567 - `'mssql'` 568 - `'mysql'` 569 - `'mariadb'` 570 - `'sqlite'` 571 - `'oracle'` 572 573 datepart: str, default `'day'` 574 Which part of the date to modify. Supported values: 575 576 - `'year'` 577 - `'month'` 578 - `'day'` 579 - `'hour'` 580 - `'minute'` 581 - `'second'` 582 583 number: Union[int, float], default `0` 584 How many units to add to the date part. 585 586 begin: Union[str, datetime], default `'now'` 587 Base datetime to which to add dateparts. 588 589 db_type: Optional[str], default None 590 If provided, cast the datetime string as the type. 591 Otherwise, infer this from the input datetime value. 592 593 Returns 594 ------- 595 The appropriate `DATEADD` string for the corresponding database flavor. 596 597 Examples 598 -------- 599 >>> dateadd_str( 600 ... flavor='mssql', 601 ... begin=datetime(2022, 1, 1, 0, 0), 602 ... number=1, 603 ... ) 604 "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))" 605 >>> dateadd_str( 606 ... flavor='postgresql', 607 ... begin=datetime(2022, 1, 1, 0, 0), 608 ... number=1, 609 ... ) 610 "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'" 611 612 """ 613 from meerschaum.utils.packages import attempt_import 614 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type 615 dateutil_parser = attempt_import('dateutil.parser') 616 if 'int' in str(type(begin)).lower(): 617 num_str = str(begin) 618 if number is not None and number != 0: 619 num_str += ( 620 f' + {number}' 621 if number > 0 622 else f" - {number * -1}" 623 ) 624 return num_str 625 if not begin: 626 return '' 627 628 _original_begin = begin 629 begin_time = None 630 ### Sanity check: make sure `begin` is a valid datetime before we inject anything. 631 if not isinstance(begin, datetime): 632 try: 633 begin_time = dateutil_parser.parse(begin) 634 except Exception: 635 begin_time = None 636 else: 637 begin_time = begin 638 639 ### Unable to parse into a datetime. 640 if begin_time is None: 641 ### Throw an error if banned symbols are included in the `begin` string. 642 clean(str(begin)) 643 ### If begin is a valid datetime, wrap it in quotes. 
644 else: 645 if isinstance(begin, datetime) and begin.tzinfo is not None: 646 begin = begin.astimezone(timezone.utc) 647 begin = ( 648 f"'{begin.replace(tzinfo=None)}'" 649 if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS 650 else f"'{begin}'" 651 ) 652 653 dt_is_utc = ( 654 begin_time.tzinfo is not None 655 if begin_time is not None 656 else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1]) 657 ) 658 if db_type: 659 db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower() 660 dt_is_utc = dt_is_utc or db_type_is_utc 661 db_type = db_type or get_db_type_from_pd_type( 662 ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'), 663 flavor=flavor, 664 ) 665 666 da = "" 667 if flavor in ('postgresql', 'postgis', 'timescaledb', 'cockroachdb', 'citus'): 668 begin = ( 669 f"CAST({begin} AS {db_type})" if begin != 'now' 670 else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})" 671 ) 672 if dt_is_utc: 673 begin += " AT TIME ZONE 'UTC'" 674 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 675 676 elif flavor == 'duckdb': 677 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()' 678 if dt_is_utc: 679 begin += " AT TIME ZONE 'UTC'" 680 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 681 682 elif flavor in ('mssql',): 683 if begin_time and begin_time.microsecond != 0 and not dt_is_utc: 684 begin = begin[:-4] + "'" 685 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()' 686 if dt_is_utc: 687 begin += " AT TIME ZONE 'UTC'" 688 da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin 689 690 elif flavor in ('mysql', 'mariadb'): 691 begin = ( 692 f"CAST({begin} AS DATETIME(6))" 693 if begin != 'now' 694 else 'UTC_TIMESTAMP(6)' 695 ) 696 da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin) 697 698 elif flavor == 'sqlite': 699 da = f"datetime({begin}, '{number} {datepart}')" 700 701 elif flavor == 'oracle': 702 if begin == 'now': 703 begin = str( 704 datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y-%m-%d %H:%M:%S.%f') 705 ) 706 elif begin_time: 707 begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f')) 708 dt_format = 'YYYY-MM-DD HH24:MI:SS.FF' 709 _begin = f"'{begin}'" if begin_time else begin 710 da = ( 711 (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin) 712 + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "") 713 ) 714 return da
717def test_connection( 718 self, 719 **kw: Any 720) -> Union[bool, None]: 721 """ 722 Test if a successful connection to the database may be made. 723 724 Parameters 725 ---------- 726 **kw: 727 The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`. 728 729 Returns 730 ------- 731 `True` if a connection is made, otherwise `False` or `None` in case of failure. 732 733 """ 734 import warnings 735 from meerschaum.connectors.poll import retry_connect 736 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 737 _default_kw.update(kw) 738 with warnings.catch_warnings(): 739 warnings.filterwarnings('ignore', 'Could not') 740 try: 741 return retry_connect(**_default_kw) 742 except Exception: 743 return False
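A typical call, assuming a configured connector (the `sql:main` key here is illustrative):

```python
import meerschaum as mrsm

conn = mrsm.get_connector('sql:main')
print(conn.test_connection())
### True on success; False (or None) when the connection fails.
```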
746def get_distinct_col_count( 747 col: str, 748 query: str, 749 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 750 debug: bool = False 751) -> Optional[int]: 752 """ 753 Returns the number of distinct items in a column of a SQL query. 754 755 Parameters 756 ---------- 757 col: str: 758 The column in the query to count. 759 760 query: str: 761 The SQL query to count from. 762 763 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 764 The SQLConnector to execute the query. 765 766 debug: bool, default False: 767 Verbosity toggle. 768 769 Returns 770 ------- 771 An `int` of the number of distinct items in the column or `None` if the query fails. 772 773 """ 774 if connector is None: 775 connector = mrsm.get_connector('sql') 776 777 _col_name = sql_item_name(col, connector.flavor, None) 778 779 _meta_query = ( 780 f""" 781 WITH src AS ( {query} ), 782 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 783 SELECT COUNT(*) FROM dist""" 784 ) if connector.flavor not in ('mysql', 'mariadb') else ( 785 f""" 786 SELECT COUNT(*) 787 FROM ( 788 SELECT DISTINCT {_col_name} 789 FROM ({query}) AS src 790 ) AS dist""" 791 ) 792 793 result = connector.value(_meta_query, debug=debug) 794 try: 795 return int(result) 796 except Exception: 797 return None
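For example, counting the distinct values of a column in an arbitrary query (the table and column names are illustrative):

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_distinct_col_count

count = get_distinct_col_count(
    'station_id',
    "SELECT * FROM weather",
    connector=mrsm.get_connector('sql:main'),
)
```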
800def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 801 """ 802 Parse SQL items depending on the flavor. 803 804 Parameters 805 ---------- 806 item: str 807 The database item (table, view, etc.) in need of quotes. 808 809 flavor: str 810 The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.). 811 812 schema: Optional[str], default None 813 If provided, prefix the table name with the schema. 814 815 Returns 816 ------- 817 A `str` which contains the input `item` wrapped in the corresponding escape characters. 818 819 Examples 820 -------- 821 >>> sql_item_name('table', 'sqlite') 822 '"table"' 823 >>> sql_item_name('table', 'mssql') 824 "[table]" 825 >>> sql_item_name('table', 'postgresql', schema='abc') 826 '"abc"."table"' 827 828 """ 829 truncated_item = truncate_item_name(str(item), flavor) 830 if flavor == 'oracle': 831 truncated_item = pg_capital(truncated_item, quote_capitals=True) 832 ### NOTE: System-reserved words must be quoted. 833 if truncated_item.lower() in ( 834 'float', 'varchar', 'nvarchar', 'clob', 835 'boolean', 'integer', 'table', 'row', 836 ): 837 wrappers = ('"', '"') 838 else: 839 wrappers = ('', '') 840 else: 841 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 842 843 ### NOTE: SQLite does not support schemas. 844 if flavor == 'sqlite': 845 schema = None 846 elif flavor == 'mssql' and str(item).startswith('#'): 847 schema = None 848 849 schema_prefix = ( 850 (wrappers[0] + schema + wrappers[1] + '.') 851 if schema is not None 852 else '' 853 ) 854 855 return schema_prefix + wrappers[0] + truncated_item + wrappers[1]
858def pg_capital(s: str, quote_capitals: bool = True) -> str: 859 """ 860 If string contains a capital letter, wrap it in double quotes. 861 862 Parameters 863 ---------- 864 s: str 865 The string to be escaped. 866 867 quote_capitals: bool, default True 868 If `False`, do not quote strings that contain only a mix of capital and lower-case letters. 869 870 Returns 871 ------- 872 The input string wrapped in quotes only if it needs them. 873 874 Examples 875 -------- 876 >>> pg_capital("My Table") 877 '"My Table"' 878 >>> pg_capital('my_table') 879 'my_table' 880 881 """ 882 if s.startswith('"') and s.endswith('"'): 883 return s 884 885 s = s.replace('"', '') 886 887 needs_quotes = s.startswith('_') 888 if not needs_quotes: 889 for c in s: 890 if c == '_': 891 continue 892 893 if not c.isalnum() or (quote_capitals and c.isupper()): 894 needs_quotes = True 895 break 896 897 if needs_quotes: 898 return '"' + s + '"' 899 900 return s
903def oracle_capital(s: str) -> str: 904 """ 905 Capitalize the string of an item on an Oracle database. 906 """ 907 return s
910def truncate_item_name(item: str, flavor: str) -> str: 911 """ 912 Truncate item names to stay within the database flavor's character limit. 913 914 Parameters 915 ---------- 916 item: str 917 The database item being referenced. This string is the "canonical" name internally. 918 919 flavor: str 920 The flavor of the database on which `item` resides. 921 922 Returns 923 ------- 924 The truncated string. 925 """ 926 from meerschaum.utils.misc import truncate_string_sections 927 return truncate_string_sections( 928 item, max_len=max_name_lens.get(flavor, max_name_lens['default']) 929 )
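A short sketch; the exact truncated form depends on the flavor's entry in `max_name_lens`:

```python
from meerschaum.utils.sql import truncate_item_name

item = 'a_very_long_table_name_with_many_sections'
truncated = truncate_item_name(item, 'oracle')
### `truncated` fits within Oracle's identifier length limit.
```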
932def build_where( 933 params: Dict[str, Any], 934 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 935 with_where: bool = True, 936 flavor: str = 'postgresql', 937) -> str: 938 """ 939 Build the `WHERE` clause based on the input criteria. 940 941 Parameters 942 ---------- 943 params: Dict[str, Any]: 944 The keywords dictionary to convert into a WHERE clause. 945 If a value is a string which begins with an underscore, negate that value 946 (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). 947 A value of `_None` will be interpreted as `IS NOT NULL`. 948 949 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 950 The Meerschaum SQLConnector that will be executing the query. 951 The connector is used to extract the SQL dialect. 952 953 with_where: bool, default True: 954 If `True`, include the leading `'WHERE'` string. 955 956 flavor: str, default 'postgresql' 957 If `connector` is `None`, fall back to this flavor. 958 959 Returns 960 ------- 961 A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor. 962 963 Examples 964 -------- 965 ``` 966 >>> print(build_where({'foo': [1, 2, 3]})) 967 968 WHERE 969 "foo" IN ('1', '2', '3') 970 ``` 971 """ 972 import json 973 from meerschaum.config.static import STATIC_CONFIG 974 from meerschaum.utils.warnings import warn 975 from meerschaum.utils.dtypes import value_is_null, none_if_null 976 negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix'] 977 try: 978 params_json = json.dumps(params) 979 except Exception: 980 params_json = str(params) 981 bad_words = ['drop ', '--', ';'] 982 for word in bad_words: 983 if word in params_json.lower(): 984 warn("Aborting build_where() due to possible SQL injection.") 985 return '' 986 987 query_flavor = getattr(connector, 'flavor', flavor) if connector is not None else flavor 988 where = "" 989 leading_and = "\n AND " 990 for key, value in params.items(): 991 _key = sql_item_name(key, query_flavor, None) 992 ### search across a list (i.e. 
IN syntax) 993 if isinstance(value, Iterable) and not isinstance(value, (dict, str)): 994 includes = [ 995 none_if_null(item) 996 for item in value 997 if not str(item).startswith(negation_prefix) 998 ] 999 null_includes = [item for item in includes if item is None] 1000 not_null_includes = [item for item in includes if item is not None] 1001 excludes = [ 1002 none_if_null(str(item)[len(negation_prefix):]) 1003 for item in value 1004 if str(item).startswith(negation_prefix) 1005 ] 1006 null_excludes = [item for item in excludes if item is None] 1007 not_null_excludes = [item for item in excludes if item is not None] 1008 1009 if includes: 1010 where += f"{leading_and}(" 1011 if not_null_includes: 1012 where += f"{_key} IN (" 1013 for item in not_null_includes: 1014 quoted_item = str(item).replace("'", "''") 1015 where += f"'{quoted_item}', " 1016 where = where[:-2] + ")" 1017 if null_includes: 1018 where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL" 1019 if includes: 1020 where += ")" 1021 1022 if excludes: 1023 where += f"{leading_and}(" 1024 if not_null_excludes: 1025 where += f"{_key} NOT IN (" 1026 for item in not_null_excludes: 1027 quoted_item = str(item).replace("'", "''") 1028 where += f"'{quoted_item}', " 1029 where = where[:-2] + ")" 1030 if null_excludes: 1031 where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL" 1032 if excludes: 1033 where += ")" 1034 1035 continue 1036 1037 ### search a dictionary 1038 elif isinstance(value, dict): 1039 import json 1040 where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'") 1041 continue 1042 1043 eq_sign = '=' 1044 is_null = 'IS NULL' 1045 if value_is_null(str(value).lstrip(negation_prefix)): 1046 value = ( 1047 (negation_prefix + 'None') 1048 if str(value).startswith(negation_prefix) 1049 else None 1050 ) 1051 if str(value).startswith(negation_prefix): 1052 value = str(value)[len(negation_prefix):] 1053 eq_sign = '!=' 1054 if value_is_null(value): 1055 value = None 1056 is_null = 'IS NOT NULL' 1057 quoted_value = str(value).replace("'", "''") 1058 where += ( 1059 f"{leading_and}{_key} " 1060 + (is_null if value is None else f"{eq_sign} '{quoted_value}'") 1061 ) 1062 1063 if len(where) > 1: 1064 where = ("\nWHERE\n " if with_where else '') + where[len(leading_and):] 1065 return where
1068def table_exists( 1069 table: str, 1070 connector: mrsm.connectors.sql.SQLConnector, 1071 schema: Optional[str] = None, 1072 debug: bool = False, 1073) -> bool: 1074 """Check if a table exists. 1075 1076 Parameters 1077 ---------- 1078 table: str: 1079 The name of the table in question. 1080 1081 connector: mrsm.connectors.sql.SQLConnector 1082 The connector to the database which holds the table. 1083 1084 schema: Optional[str], default None 1085 Optionally specify the table schema. 1086 Defaults to `connector.schema`. 1087 1088 debug: bool, default False: 1089 Verbosity toggle. 1090 1091 Returns 1092 ------- 1093 A `bool` indicating whether or not the table exists on the database. 1094 """ 1095 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1096 schema = schema or connector.schema 1097 insp = sqlalchemy.inspect(connector.engine) 1098 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1099 return insp.has_table(truncated_table_name, schema=schema)
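For example, guarding DDL behind an existence check (the table and schema names are illustrative):

```python
import meerschaum as mrsm
from meerschaum.utils.sql import table_exists

conn = mrsm.get_connector('sql:main')
if not table_exists('weather', conn, schema='sales'):
    print("Table does not yet exist.")
```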
1102def get_sqlalchemy_table( 1103 table: str, 1104 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1105 schema: Optional[str] = None, 1106 refresh: bool = False, 1107 debug: bool = False, 1108) -> Union['sqlalchemy.Table', None]: 1109 """ 1110 Construct a SQLAlchemy table from its name. 1111 1112 Parameters 1113 ---------- 1114 table: str 1115 The name of the table on the database. Does not need to be escaped. 1116 1117 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1118 The connector to the database which holds the table. 1119 1120 schema: Optional[str], default None 1121 Specify on which schema the table resides. 1122 Defaults to the schema set in `connector`. 1123 1124 refresh: bool, default False 1125 If `True`, rebuild the cached table object. 1126 1127 debug: bool, default False: 1128 Verbosity toggle. 1129 1130 Returns 1131 ------- 1132 A `sqlalchemy.Table` object for the table. 1133 1134 """ 1135 if connector is None: 1136 from meerschaum import get_connector 1137 connector = get_connector('sql') 1138 1139 if connector.flavor == 'duckdb': 1140 return None 1141 1142 from meerschaum.connectors.sql.tables import get_tables 1143 from meerschaum.utils.packages import attempt_import 1144 from meerschaum.utils.warnings import warn 1145 if refresh: 1146 connector.metadata.clear() 1147 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 1148 sqlalchemy = attempt_import('sqlalchemy', lazy=False) 1149 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1150 table_kwargs = { 1151 'autoload_with': connector.engine, 1152 } 1153 if schema: 1154 table_kwargs['schema'] = schema 1155 1156 if refresh or truncated_table_name not in tables: 1157 try: 1158 tables[truncated_table_name] = sqlalchemy.Table( 1159 truncated_table_name, 1160 connector.metadata, 1161 **table_kwargs 1162 ) 1163 except sqlalchemy.exc.NoSuchTableError: 1164 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 1165 return None 1166 return tables[truncated_table_name]
1169def get_table_cols_types( 1170 table: str, 1171 connectable: Union[ 1172 'mrsm.connectors.sql.SQLConnector', 1173 'sqlalchemy.orm.session.Session', 1174 'sqlalchemy.engine.base.Engine' 1175 ], 1176 flavor: Optional[str] = None, 1177 schema: Optional[str] = None, 1178 database: Optional[str] = None, 1179 debug: bool = False, 1180) -> Dict[str, str]: 1181 """ 1182 Return a dictionary mapping a table's columns to data types. 1183 This is useful for inspecting tables created during a not-yet-committed session. 1184 1185 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1186 Use this function if you are confident the table name is unique or if you have 1187 an explicit schema. 1188 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1189 1190 Parameters 1191 ---------- 1192 table: str 1193 The name of the table (unquoted). 1194 1195 connectable: Union[ 1196 'mrsm.connectors.sql.SQLConnector', 1197 'sqlalchemy.orm.session.Session', 1198 'sqlalchemy.engine.base.Engine' 1199 ] 1200 The connection object used to fetch the columns and types. 1201 1202 flavor: Optional[str], default None 1203 The database dialect flavor to use for the query. 1204 If omitted, default to `connectable.flavor`. 1205 1206 schema: Optional[str], default None 1207 If provided, restrict the query to this schema. 1208 1209 database: Optional[str], default None 1210 If provided, restrict the query to this database. 1211 1212 Returns 1213 ------- 1214 A dictionary mapping column names to data types. 1215 """ 1216 import textwrap 1217 from meerschaum.connectors import SQLConnector 1218 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1219 flavor = flavor or getattr(connectable, 'flavor', None) 1220 if not flavor: 1221 raise ValueError("Please provide a database flavor.") 1222 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1223 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1224 if flavor in NO_SCHEMA_FLAVORS: 1225 schema = None 1226 if schema is None: 1227 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1228 if flavor in ('sqlite', 'duckdb', 'oracle'): 1229 database = None 1230 table_trunc = truncate_item_name(table, flavor=flavor) 1231 table_lower = table.lower() 1232 table_upper = table.upper() 1233 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1234 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1235 db_prefix = ( 1236 "tempdb." 
1237 if flavor == 'mssql' and table.startswith('#') 1238 else "" 1239 ) 1240 1241 cols_types_query = sqlalchemy.text( 1242 textwrap.dedent(columns_types_queries.get( 1243 flavor, 1244 columns_types_queries['default'] 1245 ).format( 1246 table=table, 1247 table_trunc=table_trunc, 1248 table_lower=table_lower, 1249 table_lower_trunc=table_lower_trunc, 1250 table_upper=table_upper, 1251 table_upper_trunc=table_upper_trunc, 1252 db_prefix=db_prefix, 1253 )).lstrip().rstrip() 1254 ) 1255 1256 cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale'] 1257 result_cols_ix = dict(enumerate(cols)) 1258 1259 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1260 if not debug_kwargs and debug: 1261 dprint(cols_types_query) 1262 1263 try: 1264 result_rows = ( 1265 [ 1266 row 1267 for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall() 1268 ] 1269 if flavor != 'duckdb' 1270 else [ 1271 tuple([doc[col] for col in cols]) 1272 for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records') 1273 ] 1274 ) 1275 cols_types_docs = [ 1276 { 1277 result_cols_ix[i]: val 1278 for i, val in enumerate(row) 1279 } 1280 for row in result_rows 1281 ] 1282 cols_types_docs_filtered = [ 1283 doc 1284 for doc in cols_types_docs 1285 if ( 1286 ( 1287 not schema 1288 or doc['schema'] == schema 1289 ) 1290 and 1291 ( 1292 not database 1293 or doc['database'] == database 1294 ) 1295 ) 1296 ] 1297 1298 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1299 if cols_types_docs and not cols_types_docs_filtered: 1300 cols_types_docs_filtered = cols_types_docs 1301 1302 ### NOTE: Check for PostGIS GEOMETRY columns. 1303 geometry_cols_types = {} 1304 user_defined_cols = [ 1305 doc 1306 for doc in cols_types_docs_filtered 1307 if str(doc.get('type', None)).upper() == 'USER-DEFINED' 1308 ] 1309 if user_defined_cols: 1310 geometry_cols_types.update( 1311 get_postgis_geo_columns_types( 1312 connectable, 1313 table, 1314 schema=schema, 1315 debug=debug, 1316 ) 1317 ) 1318 1319 cols_types = { 1320 ( 1321 doc['column'] 1322 if flavor != 'oracle' else ( 1323 ( 1324 doc['column'].lower() 1325 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1326 else doc['column'] 1327 ) 1328 ) 1329 ): doc['type'].upper() + ( 1330 f'({precision},{scale})' 1331 if ( 1332 (precision := doc.get('numeric_precision', None)) 1333 and 1334 (scale := doc.get('numeric_scale', None)) 1335 ) 1336 else '' 1337 ) 1338 for doc in cols_types_docs_filtered 1339 } 1340 cols_types.update(geometry_cols_types) 1341 return cols_types 1342 except Exception as e: 1343 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1344 return {}
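The returned mapping uses upper-cased database types; a sketch (the table name and resulting types are illustrative):

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_table_cols_types

conn = mrsm.get_connector('sql:main')
cols_types = get_table_cols_types('weather', conn, schema='sales')
### e.g. {'station_id': 'BIGINT', 'ts': 'TIMESTAMP', 'val': 'DOUBLE PRECISION'}
```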
1347def get_table_cols_indices( 1348 table: str, 1349 connectable: Union[ 1350 'mrsm.connectors.sql.SQLConnector', 1351 'sqlalchemy.orm.session.Session', 1352 'sqlalchemy.engine.base.Engine' 1353 ], 1354 flavor: Optional[str] = None, 1355 schema: Optional[str] = None, 1356 database: Optional[str] = None, 1357 debug: bool = False, 1358) -> Dict[str, List[str]]: 1359 """ 1360 Return a dictionary mapping a table's columns to lists of indices. 1361 This is useful for inspecting tables created during a not-yet-committed session. 1362 1363 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1364 Use this function if you are confident the table name is unique or if you have 1365 an explicit schema. 1366 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1367 1368 Parameters 1369 ---------- 1370 table: str 1371 The name of the table (unquoted). 1372 1373 connectable: Union[ 1374 'mrsm.connectors.sql.SQLConnector', 1375 'sqlalchemy.orm.session.Session', 1376 'sqlalchemy.engine.base.Engine' 1377 ] 1378 The connection object used to fetch the columns and types. 1379 1380 flavor: Optional[str], default None 1381 The database dialect flavor to use for the query. 1382 If omitted, default to `connectable.flavor`. 1383 1384 schema: Optional[str], default None 1385 If provided, restrict the query to this schema. 1386 1387 database: Optional[str], default None 1388 If provided, restrict the query to this database. 1389 1390 Returns 1391 ------- 1392 A dictionary mapping column names to a list of indices. 1393 """ 1394 import textwrap 1395 from collections import defaultdict 1396 from meerschaum.connectors import SQLConnector 1397 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1398 flavor = flavor or getattr(connectable, 'flavor', None) 1399 if not flavor: 1400 raise ValueError("Please provide a database flavor.") 1401 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1402 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1403 if flavor in NO_SCHEMA_FLAVORS: 1404 schema = None 1405 if schema is None: 1406 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1407 if flavor in ('sqlite', 'duckdb', 'oracle'): 1408 database = None 1409 table_trunc = truncate_item_name(table, flavor=flavor) 1410 table_lower = table.lower() 1411 table_upper = table.upper() 1412 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1413 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1414 db_prefix = ( 1415 "tempdb." 
1416 if flavor == 'mssql' and table.startswith('#') 1417 else "" 1418 ) 1419 1420 cols_indices_query = sqlalchemy.text( 1421 textwrap.dedent(columns_indices_queries.get( 1422 flavor, 1423 columns_indices_queries['default'] 1424 ).format( 1425 table=table, 1426 table_trunc=table_trunc, 1427 table_lower=table_lower, 1428 table_lower_trunc=table_lower_trunc, 1429 table_upper=table_upper, 1430 table_upper_trunc=table_upper_trunc, 1431 db_prefix=db_prefix, 1432 schema=schema, 1433 )).lstrip().rstrip() 1434 ) 1435 1436 cols = ['database', 'schema', 'table', 'column', 'index', 'index_type'] 1437 if flavor == 'mssql': 1438 cols.append('clustered') 1439 result_cols_ix = dict(enumerate(cols)) 1440 1441 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1442 if not debug_kwargs and debug: 1443 dprint(cols_indices_query) 1444 1445 try: 1446 result_rows = ( 1447 [ 1448 row 1449 for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall() 1450 ] 1451 if flavor != 'duckdb' 1452 else [ 1453 tuple([doc[col] for col in cols]) 1454 for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records') 1455 ] 1456 ) 1457 cols_types_docs = [ 1458 { 1459 result_cols_ix[i]: val 1460 for i, val in enumerate(row) 1461 } 1462 for row in result_rows 1463 ] 1464 cols_types_docs_filtered = [ 1465 doc 1466 for doc in cols_types_docs 1467 if ( 1468 ( 1469 not schema 1470 or doc['schema'] == schema 1471 ) 1472 and 1473 ( 1474 not database 1475 or doc['database'] == database 1476 ) 1477 ) 1478 ] 1479 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1480 if cols_types_docs and not cols_types_docs_filtered: 1481 cols_types_docs_filtered = cols_types_docs 1482 1483 cols_indices = defaultdict(lambda: []) 1484 for doc in cols_types_docs_filtered: 1485 col = ( 1486 doc['column'] 1487 if flavor != 'oracle' 1488 else ( 1489 doc['column'].lower() 1490 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1491 else doc['column'] 1492 ) 1493 ) 1494 index_doc = { 1495 'name': doc.get('index', None), 1496 'type': doc.get('index_type', None) 1497 } 1498 if flavor == 'mssql': 1499 index_doc['clustered'] = doc.get('clustered', None) 1500 cols_indices[col].append(index_doc) 1501 1502 return dict(cols_indices) 1503 except Exception as e: 1504 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1505 return {}
1508def get_update_queries( 1509 target: str, 1510 patch: str, 1511 connectable: Union[ 1512 mrsm.connectors.sql.SQLConnector, 1513 'sqlalchemy.orm.session.Session' 1514 ], 1515 join_cols: Iterable[str], 1516 flavor: Optional[str] = None, 1517 upsert: bool = False, 1518 datetime_col: Optional[str] = None, 1519 schema: Optional[str] = None, 1520 patch_schema: Optional[str] = None, 1521 identity_insert: bool = False, 1522 null_indices: bool = True, 1523 cast_columns: bool = True, 1524 debug: bool = False, 1525) -> List[str]: 1526 """ 1527 Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to a target table. 1528 1529 Parameters 1530 ---------- 1531 target: str 1532 The name of the target table. 1533 1534 patch: str 1535 The name of the patch table. This should have the same shape as the target. 1536 1537 connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session] 1538 The `SQLConnector` or SQLAlchemy session which will later execute the queries. 1539 1540 join_cols: List[str] 1541 The columns to use to join the patch to the target. 1542 1543 flavor: Optional[str], default None 1544 If using a SQLAlchemy session, provide the expected database flavor. 1545 1546 upsert: bool, default False 1547 If `True`, return an upsert query rather than an update. 1548 1549 datetime_col: Optional[str], default None 1550 If provided, bound the join query using this column as the datetime index. 1551 This must be present on both tables. 1552 1553 schema: Optional[str], default None 1554 If provided, use this schema when quoting the target table. 1555 Defaults to `connector.schema`. 1556 1557 patch_schema: Optional[str], default None 1558 If provided, use this schema when quoting the patch table. 1559 Defaults to `schema`. 1560 1561 identity_insert: bool, default False 1562 If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries. 1563 Only applies for MSSQL upserts. 1564 1565 null_indices: bool, default True 1566 If `False`, do not coalesce index columns before joining. 1567 1568 cast_columns: bool, default True 1569 If `False`, do not cast update columns to the target table types. 1570 1571 debug: bool, default False 1572 Verbosity toggle. 1573 1574 Returns 1575 ------- 1576 A list of query strings to perform the update operation. 
1577 """ 1578 import textwrap 1579 from meerschaum.connectors import SQLConnector 1580 from meerschaum.utils.debug import dprint 1581 from meerschaum.utils.dtypes import are_dtypes_equal 1582 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type 1583 flavor = flavor or getattr(connectable, 'flavor', None) 1584 if not flavor: 1585 raise ValueError("Provide a flavor if using a SQLAlchemy session.") 1586 if ( 1587 flavor == 'sqlite' 1588 and isinstance(connectable, SQLConnector) 1589 and connectable.db_version < '3.33.0' 1590 ): 1591 flavor = 'sqlite_delete_insert' 1592 flavor_key = (f'{flavor}-upsert' if upsert else flavor) 1593 base_queries = UPDATE_QUERIES.get( 1594 flavor_key, 1595 UPDATE_QUERIES['default'] 1596 ) 1597 if not isinstance(base_queries, list): 1598 base_queries = [base_queries] 1599 schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None) 1600 patch_schema = patch_schema or schema 1601 target_table_columns = get_table_cols_types( 1602 target, 1603 connectable, 1604 flavor=flavor, 1605 schema=schema, 1606 debug=debug, 1607 ) 1608 patch_table_columns = get_table_cols_types( 1609 patch, 1610 connectable, 1611 flavor=flavor, 1612 schema=patch_schema, 1613 debug=debug, 1614 ) 1615 1616 patch_cols_str = ', '.join( 1617 [ 1618 sql_item_name(col, flavor) 1619 for col in patch_table_columns 1620 ] 1621 ) 1622 patch_cols_prefixed_str = ', '.join( 1623 [ 1624 'p.' + sql_item_name(col, flavor) 1625 for col in patch_table_columns 1626 ] 1627 ) 1628 1629 join_cols_str = ', '.join( 1630 [ 1631 sql_item_name(col, flavor) 1632 for col in join_cols 1633 ] 1634 ) 1635 1636 value_cols = [] 1637 join_cols_types = [] 1638 if debug: 1639 dprint("target_table_columns:") 1640 mrsm.pprint(target_table_columns) 1641 for c_name, c_type in target_table_columns.items(): 1642 if c_name not in patch_table_columns: 1643 continue 1644 if flavor in DB_FLAVORS_CAST_DTYPES: 1645 c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type) 1646 ( 1647 join_cols_types 1648 if c_name in join_cols 1649 else value_cols 1650 ).append((c_name, c_type)) 1651 if debug: 1652 dprint(f"value_cols: {value_cols}") 1653 1654 if not join_cols_types: 1655 return [] 1656 if not value_cols and not upsert: 1657 return [] 1658 1659 coalesce_join_cols_str = ', '.join( 1660 [ 1661 ( 1662 ( 1663 'COALESCE(' 1664 + sql_item_name(c_name, flavor) 1665 + ', ' 1666 + get_null_replacement(c_type, flavor) 1667 + ')' 1668 ) 1669 if null_indices 1670 else sql_item_name(c_name, flavor) 1671 ) 1672 for c_name, c_type in join_cols_types 1673 ] 1674 ) 1675 1676 update_or_nothing = ('UPDATE' if value_cols else 'NOTHING') 1677 1678 def sets_subquery(l_prefix: str, r_prefix: str): 1679 if not value_cols: 1680 return '' 1681 1682 utc_value_cols = { 1683 c_name 1684 for c_name, c_type in value_cols 1685 if ('utc' in get_pd_type_from_db_type(c_type).lower()) 1686 } if flavor not in TIMEZONE_NAIVE_FLAVORS else set() 1687 1688 cast_func_cols = { 1689 c_name: ( 1690 ('', '', '') 1691 if not cast_columns or ( 1692 flavor == 'oracle' 1693 and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes') 1694 ) 1695 else ( 1696 ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + ( 1697 " AT TIME ZONE 'UTC'" 1698 if c_name in utc_value_cols 1699 else '' 1700 )) 1701 if flavor != 'sqlite' 1702 else ('', '', '') 1703 ) 1704 ) 1705 for c_name, c_type in value_cols 1706 } 1707 return 'SET ' + ',\n'.join([ 1708 ( 1709 l_prefix + sql_item_name(c_name, flavor, None) 1710 + ' = ' 1711 + 
cast_func_cols[c_name][0] 1712 + r_prefix + sql_item_name(c_name, flavor, None) 1713 + cast_func_cols[c_name][1] 1714 + cast_func_cols[c_name][2] 1715 ) for c_name, c_type in value_cols 1716 ]) 1717 1718 def and_subquery(l_prefix: str, r_prefix: str): 1719 return '\n AND\n '.join([ 1720 ( 1721 ( 1722 "COALESCE(" 1723 + l_prefix 1724 + sql_item_name(c_name, flavor, None) 1725 + ", " 1726 + get_null_replacement(c_type, flavor) 1727 + ")" 1728 + '\n =\n ' 1729 + "COALESCE(" 1730 + r_prefix 1731 + sql_item_name(c_name, flavor, None) 1732 + ", " 1733 + get_null_replacement(c_type, flavor) 1734 + ")" 1735 ) 1736 if null_indices 1737 else ( 1738 l_prefix 1739 + sql_item_name(c_name, flavor, None) 1740 + ' = ' 1741 + r_prefix 1742 + sql_item_name(c_name, flavor, None) 1743 ) 1744 ) for c_name, c_type in join_cols_types 1745 ]) 1746 1747 skip_query_val = "" 1748 target_table_name = sql_item_name(target, flavor, schema) 1749 patch_table_name = sql_item_name(patch, flavor, patch_schema) 1750 dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None 1751 date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]' 1752 min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]' 1753 max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]' 1754 date_bounds_subquery = ( 1755 f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table}) 1756 AND 1757 f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})""" 1758 if datetime_col 1759 else "1 = 1" 1760 ) 1761 with_temp_date_bounds = f"""WITH [date_bounds] AS ( 1762 SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name} 1763 FROM {patch_table_name} 1764 )""" if datetime_col else "" 1765 identity_insert_on = ( 1766 f"SET IDENTITY_INSERT {target_table_name} ON" 1767 if identity_insert 1768 else skip_query_val 1769 ) 1770 identity_insert_off = ( 1771 f"SET IDENTITY_INSERT {target_table_name} OFF" 1772 if identity_insert 1773 else skip_query_val 1774 ) 1775 1776 ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices. 
1777 when_matched_update_sets_subquery_none = "" if not value_cols else ( 1778 "\n WHEN MATCHED THEN\n" 1779 f" UPDATE {sets_subquery('', 'p.')}" 1780 ) 1781 1782 cols_equal_values = '\n,'.join( 1783 [ 1784 f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})" 1785 for c_name, c_type in value_cols 1786 ] 1787 ) 1788 on_duplicate_key_update = ( 1789 "ON DUPLICATE KEY UPDATE" 1790 if value_cols 1791 else "" 1792 ) 1793 ignore = "IGNORE " if not value_cols else "" 1794 1795 formatted_queries = [ 1796 textwrap.dedent(base_query.format( 1797 sets_subquery_none=sets_subquery('', 'p.'), 1798 sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'), 1799 sets_subquery_f=sets_subquery('f.', 'p.'), 1800 and_subquery_f=and_subquery('p.', 'f.'), 1801 and_subquery_t=and_subquery('p.', 't.'), 1802 target_table_name=target_table_name, 1803 patch_table_name=patch_table_name, 1804 patch_cols_str=patch_cols_str, 1805 patch_cols_prefixed_str=patch_cols_prefixed_str, 1806 date_bounds_subquery=date_bounds_subquery, 1807 join_cols_str=join_cols_str, 1808 coalesce_join_cols_str=coalesce_join_cols_str, 1809 update_or_nothing=update_or_nothing, 1810 when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none, 1811 cols_equal_values=cols_equal_values, 1812 on_duplicate_key_update=on_duplicate_key_update, 1813 ignore=ignore, 1814 with_temp_date_bounds=with_temp_date_bounds, 1815 identity_insert_on=identity_insert_on, 1816 identity_insert_off=identity_insert_off, 1817 )).lstrip().rstrip() 1818 for base_query in base_queries 1819 ] 1820 1821 ### NOTE: Allow for skipping some queries. 1822 return [query for query in formatted_queries if query]
Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to a target table.
Parameters
- target (str): The name of the target table.
- patch (str): The name of the patch table. This should have the same shape as the target.
- connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]): The `SQLConnector` or SQLAlchemy session which will later execute the queries.
- join_cols (List[str]): The columns to use to join the patch to the target.
- flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
- upsert (bool, default False): If `True`, return an upsert query rather than an update.
- datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
- schema (Optional[str], default None): If provided, use this schema when quoting the target table. Defaults to `connector.schema`.
- patch_schema (Optional[str], default None): If provided, use this schema when quoting the patch table. Defaults to `schema`.
- identity_insert (bool, default False): If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries. Only applies to MSSQL upserts.
- null_indices (bool, default True): If `False`, do not coalesce index columns before joining.
- cast_columns (bool, default True): If `False`, do not cast update columns to the target table types.
- debug (bool, default False): Verbosity toggle.
Returns
- A list of query strings to perform the update operation.
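A minimal sketch of applying a patch with the built-in `sql:local` SQLite connector; the table names `weather` and `weather_patch` are hypothetical and assumed to already exist with matching shapes:

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_update_queries

conn = mrsm.get_connector('sql:local')

### Build the flavor-appropriate UPDATE (or upsert) queries,
### joining the patch to the target on the 'station' column.
queries = get_update_queries(
    target='weather',
    patch='weather_patch',
    connectable=conn,
    join_cols=['station'],
    datetime_col='timestamp',
)

### Execute the queries in order.
for query in queries:
    conn.exec(query)
```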
1825def get_null_replacement(typ: str, flavor: str) -> str: 1826 """ 1827 Return a value that may temporarily be used in place of NULL for this type. 1828 1829 Parameters 1830 ---------- 1831 typ: str 1832 The typ to be converted to NULL. 1833 1834 flavor: str 1835 The database flavor for which this value will be used. 1836 1837 Returns 1838 ------- 1839 A value which may stand in place of NULL for this type. 1840 `'None'` is returned if a value cannot be determined. 1841 """ 1842 from meerschaum.utils.dtypes import are_dtypes_equal 1843 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 1844 if 'geometry' in typ.lower(): 1845 return '010100000000008058346FCDC100008058346FCDC1' 1846 if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'): 1847 return '-987654321' 1848 if 'bool' in typ.lower() or typ.lower() == 'bit': 1849 bool_typ = ( 1850 PD_TO_DB_DTYPES_FLAVORS 1851 .get('bool', {}) 1852 .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default']) 1853 ) 1854 if flavor in DB_FLAVORS_CAST_DTYPES: 1855 bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ) 1856 val_to_cast = ( 1857 -987654321 1858 if flavor in ('mysql', 'mariadb') 1859 else 0 1860 ) 1861 return f'CAST({val_to_cast} AS {bool_typ})' 1862 if 'time' in typ.lower() or 'date' in typ.lower(): 1863 db_type = typ if typ.isupper() else None 1864 return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type) 1865 if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',): 1866 return '-987654321.0' 1867 if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char': 1868 return "'-987654321'" 1869 if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'): 1870 return '00' 1871 if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'): 1872 magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5' 1873 if flavor == 'mssql': 1874 return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)" 1875 return f"'{magic_val}'" 1876 return ('n' if flavor == 'oracle' else '') + "'-987654321'"
Return a value that may temporarily be used in place of NULL for this type.
Parameters
- typ (str): The database type for which a NULL stand-in value is needed.
- flavor (str): The database flavor for which this value will be used.
Returns
- A value which may stand in place of NULL for this type. `'None'` is returned if a value cannot be determined.
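A quick illustration of the magic stand-in values (the printed strings are indicative; exact casts vary by type and flavor):

```python
from meerschaum.utils.sql import get_null_replacement

print(get_null_replacement('INT', 'postgresql'))
### '-987654321'

print(get_null_replacement('UNIQUEIDENTIFIER', 'mssql'))
### "CAST('DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5' AS UNIQUEIDENTIFIER)"
```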
1879def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]: 1880 """ 1881 Fetch the database version if possible. 1882 """ 1883 version_name = sql_item_name('version', conn.flavor, None) 1884 version_query = version_queries.get( 1885 conn.flavor, 1886 version_queries['default'] 1887 ).format(version_name=version_name) 1888 return conn.value(version_query, debug=debug)
Fetch the database version if possible.
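For example, against the built-in `sql:local` SQLite connector:

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_db_version

conn = mrsm.get_connector('sql:local')
print(get_db_version(conn))
### e.g. '3.45.1' (depends on the bundled SQLite version)
```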
1891def get_rename_table_queries( 1892 old_table: str, 1893 new_table: str, 1894 flavor: str, 1895 schema: Optional[str] = None, 1896) -> List[str]: 1897 """ 1898 Return queries to alter a table's name. 1899 1900 Parameters 1901 ---------- 1902 old_table: str 1903 The unquoted name of the old table. 1904 1905 new_table: str 1906 The unquoted name of the new table. 1907 1908 flavor: str 1909 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 1910 1911 schema: Optional[str], default None 1912 The schema on which the table resides. 1913 1914 Returns 1915 ------- 1916 A list of `ALTER TABLE` or equivalent queries for the database flavor. 1917 """ 1918 old_table_name = sql_item_name(old_table, flavor, schema) 1919 new_table_name = sql_item_name(new_table, flavor, None) 1920 tmp_table = '_tmp_rename_' + new_table 1921 tmp_table_name = sql_item_name(tmp_table, flavor, schema) 1922 if flavor == 'mssql': 1923 return [f"EXEC sp_rename '{old_table}', '{new_table}'"] 1924 1925 if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else "" 1926 if flavor == 'duckdb': 1927 return ( 1928 get_create_table_queries( 1929 f"SELECT * FROM {old_table_name}", 1930 tmp_table, 1931 'duckdb', 1932 schema, 1933 ) + get_create_table_queries( 1934 f"SELECT * FROM {tmp_table_name}", 1935 new_table, 1936 'duckdb', 1937 schema, 1938 ) + [ 1939 f"DROP TABLE {if_exists_str} {tmp_table_name}", 1940 f"DROP TABLE {if_exists_str} {old_table_name}", 1941 ] 1942 ) 1943 1944 return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]
Return queries to alter a table's name.
Parameters
- old_table (str): The unquoted name of the old table.
- new_table (str): The unquoted name of the new table.
- flavor (str): The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).
- schema (Optional[str], default None): The schema on which the table resides.
Returns
- A list of `ALTER TABLE` or equivalent queries for the database flavor.
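For example (identifier quoting shown here assumes standard double-quoted identifiers for PostgreSQL):

```python
from meerschaum.utils.sql import get_rename_table_queries

print(get_rename_table_queries('old_tbl', 'new_tbl', 'postgresql'))
### ['ALTER TABLE "old_tbl" RENAME TO "new_tbl"']

print(get_rename_table_queries('old_tbl', 'new_tbl', 'mssql'))
### ["EXEC sp_rename 'old_tbl', 'new_tbl'"]
```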
1947def get_create_table_query( 1948 query_or_dtypes: Union[str, Dict[str, str]], 1949 new_table: str, 1950 flavor: str, 1951 schema: Optional[str] = None, 1952) -> str: 1953 """ 1954 NOTE: This function is deprecated. Use `get_create_table_queries()` instead. 1955 1956 Return a query to create a new table from a `SELECT` query. 1957 1958 Parameters 1959 ---------- 1960 query: Union[str, Dict[str, str]] 1961 The select query to use for the creation of the table. 1962 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 1963 1964 new_table: str 1965 The unquoted name of the new table. 1966 1967 flavor: str 1968 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 1969 1970 schema: Optional[str], default None 1971 The schema on which the table will reside. 1972 1973 Returns 1974 ------- 1975 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 1976 """ 1977 return get_create_table_queries( 1978 query_or_dtypes, 1979 new_table, 1980 flavor, 1981 schema=schema, 1982 primary_key=None, 1983 )[0]
NOTE: This function is deprecated. Use `get_create_table_queries()` instead.
Return a query to create a new table from a `SELECT` query.
Parameters
- query_or_dtypes (Union[str, Dict[str, str]]): The select query to use for the creation of the table. If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.
- new_table (str): The unquoted name of the new table.
- flavor (str): The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).
- schema (Optional[str], default None): The schema on which the table will reside.
Returns
- A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
1986def get_create_table_queries( 1987 query_or_dtypes: Union[str, Dict[str, str]], 1988 new_table: str, 1989 flavor: str, 1990 schema: Optional[str] = None, 1991 primary_key: Optional[str] = None, 1992 primary_key_db_type: Optional[str] = None, 1993 autoincrement: bool = False, 1994 datetime_column: Optional[str] = None, 1995) -> List[str]: 1996 """ 1997 Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary. 1998 1999 Parameters 2000 ---------- 2001 query_or_dtypes: Union[str, Dict[str, str]] 2002 The select query to use for the creation of the table. 2003 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 2004 2005 new_table: str 2006 The unquoted name of the new table. 2007 2008 flavor: str 2009 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2010 2011 schema: Optional[str], default None 2012 The schema on which the table will reside. 2013 2014 primary_key: Optional[str], default None 2015 If provided, designate this column as the primary key in the new table. 2016 2017 primary_key_db_type: Optional[str], default None 2018 If provided, alter the primary key to this type (to set NOT NULL constraint). 2019 2020 autoincrement: bool, default False 2021 If `True` and `primary_key` is provided, create the `primary_key` column 2022 as an auto-incrementing integer column. 2023 2024 datetime_column: Optional[str], default None 2025 If provided, include this column in the primary key. 2026 Applicable to TimescaleDB only. 2027 2028 Returns 2029 ------- 2030 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2031 """ 2032 if not isinstance(query_or_dtypes, (str, dict)): 2033 raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.") 2034 2035 method = ( 2036 _get_create_table_query_from_cte 2037 if isinstance(query_or_dtypes, str) 2038 else _get_create_table_query_from_dtypes 2039 ) 2040 return method( 2041 query_or_dtypes, 2042 new_table, 2043 flavor, 2044 schema=schema, 2045 primary_key=primary_key, 2046 primary_key_db_type=primary_key_db_type, 2047 autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS), 2048 datetime_column=datetime_column, 2049 )
Return the queries to create a new table from a `SELECT` query or a `dtypes` dictionary.
Parameters
- query_or_dtypes (Union[str, Dict[str, str]]): The select query to use for the creation of the table. If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.
- new_table (str): The unquoted name of the new table.
- flavor (str): The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).
- schema (Optional[str], default None): The schema on which the table will reside.
- primary_key (Optional[str], default None): If provided, designate this column as the primary key in the new table.
- primary_key_db_type (Optional[str], default None): If provided, alter the primary key to this type (to set the NOT NULL constraint).
- autoincrement (bool, default False): If `True` and `primary_key` is provided, create the `primary_key` column as an auto-incrementing integer column.
- datetime_column (Optional[str], default None): If provided, include this column in the primary key. Applicable to TimescaleDB only.
Returns
- A list of `CREATE TABLE` (or `SELECT INTO`) queries for the database flavor.
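A minimal sketch, building the statements from a hypothetical `dtypes` dictionary (pandas-style dtypes, as used elsewhere in Meerschaum):

```python
from meerschaum.utils.sql import get_create_table_queries

### Map column names to pandas-style dtypes and designate 'id' as the primary key.
queries = get_create_table_queries(
    {'id': 'int64', 'ts': 'datetime64[ns]', 'value': 'float64'},
    'readings',
    'postgresql',
    primary_key='id',
)
for query in queries:
    print(query)
```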
2290def wrap_query_with_cte( 2291 sub_query: str, 2292 parent_query: str, 2293 flavor: str, 2294 cte_name: str = "src", 2295) -> str: 2296 """ 2297 Wrap a subquery in a CTE and append an encapsulating query. 2298 2299 Parameters 2300 ---------- 2301 sub_query: str 2302 The query to be referenced. This may itself contain CTEs. 2303 Unless `cte_name` is provided, this will be aliased as `src`. 2304 2305 parent_query: str 2306 The larger query to append which references the subquery. 2307 This must not contain CTEs. 2308 2309 flavor: str 2310 The database flavor, e.g. `'mssql'`. 2311 2312 cte_name: str, default 'src' 2313 The CTE alias, defaults to `src`. 2314 2315 Returns 2316 ------- 2317 An encapsulating query which allows you to treat `sub_query` as a temporary table. 2318 2319 Examples 2320 -------- 2321 2322 ```python 2323 from meerschaum.utils.sql import wrap_query_with_cte 2324 sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo" 2325 parent_query = "SELECT newval * 3 FROM src" 2326 query = wrap_query_with_cte(sub_query, parent_query, 'mssql') 2327 print(query) 2328 # WITH foo AS (SELECT 1 AS val), 2329 # [src] AS ( 2330 # SELECT (val * 2) AS newval FROM foo 2331 # ) 2332 # SELECT newval * 3 FROM src 2333 ``` 2334 2335 """ 2336 import textwrap 2337 sub_query = sub_query.lstrip() 2338 cte_name_quoted = sql_item_name(cte_name, flavor, None) 2339 2340 if flavor in NO_CTE_FLAVORS: 2341 return ( 2342 parent_query 2343 .replace(cte_name_quoted, '--MRSM_SUBQUERY--') 2344 .replace(cte_name, '--MRSM_SUBQUERY--') 2345 .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}") 2346 ) 2347 2348 if sub_query.lstrip().lower().startswith('with '): 2349 final_select_ix = sub_query.lower().rfind('select') 2350 return ( 2351 sub_query[:final_select_ix].rstrip() + ',\n' 2352 + f"{cte_name_quoted} AS (\n" 2353 + ' ' + sub_query[final_select_ix:] 2354 + "\n)\n" 2355 + parent_query 2356 ) 2357 2358 return ( 2359 f"WITH {cte_name_quoted} AS (\n" 2360 f"{textwrap.indent(sub_query, ' ')}\n" 2361 f")\n{parent_query}" 2362 )
Wrap a subquery in a CTE and append an encapsulating query.
Parameters
- sub_query (str): The query to be referenced. This may itself contain CTEs. Unless `cte_name` is provided, this will be aliased as `src`.
- parent_query (str): The larger query to append which references the subquery. This must not contain CTEs.
- flavor (str): The database flavor, e.g. `'mssql'`.
- cte_name (str, default 'src'): The CTE alias, defaults to `src`.
Returns
- An encapsulating query which allows you to treat `sub_query` as a temporary table.
Examples
from meerschaum.utils.sql import wrap_query_with_cte
sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
parent_query = "SELECT newval * 3 FROM src"
query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
print(query)
# WITH foo AS (SELECT 1 AS val),
# [src] AS (
# SELECT (val * 2) AS newval FROM foo
# )
# SELECT newval * 3 FROM src
2365def format_cte_subquery( 2366 sub_query: str, 2367 flavor: str, 2368 sub_name: str = 'src', 2369 cols_to_select: Union[List[str], str] = '*', 2370) -> str: 2371 """ 2372 Given a subquery, build a wrapper query that selects from the CTE subquery. 2373 2374 Parameters 2375 ---------- 2376 sub_query: str 2377 The subquery to wrap. 2378 2379 flavor: str 2380 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2381 2382 sub_name: str, default 'src' 2383 If possible, give this name to the CTE (must be unquoted). 2384 2385 cols_to_select: Union[List[str], str], default '' 2386 If specified, choose which columns to select from the CTE. 2387 If a list of strings is provided, each item will be quoted and joined with commas. 2388 If a string is given, assume it is quoted and insert it into the query. 2389 2390 Returns 2391 ------- 2392 A wrapper query that selects from the CTE. 2393 """ 2394 quoted_sub_name = sql_item_name(sub_name, flavor, None) 2395 cols_str = ( 2396 cols_to_select 2397 if isinstance(cols_to_select, str) 2398 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 2399 ) 2400 parent_query = ( 2401 f"SELECT {cols_str}\n" 2402 f"FROM {quoted_sub_name}" 2403 ) 2404 return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
Given a subquery, build a wrapper query that selects from the CTE subquery.
Parameters
- sub_query (str): The subquery to wrap.
- flavor (str): The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).
- sub_name (str, default 'src'): If possible, give this name to the CTE (must be unquoted).
- cols_to_select (Union[List[str], str], default '*'): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.
Returns
- A wrapper query that selects from the CTE.
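For example (exact whitespace may vary slightly):

```python
from meerschaum.utils.sql import format_cte_subquery

print(format_cte_subquery("SELECT 1 AS val", 'postgresql', cols_to_select=['val']))
### WITH "src" AS (
###     SELECT 1 AS val
### )
### SELECT "val"
### FROM "src"
```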
2407def session_execute( 2408 session: 'sqlalchemy.orm.session.Session', 2409 queries: Union[List[str], str], 2410 with_results: bool = False, 2411 debug: bool = False, 2412) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 2413 """ 2414 Similar to `SQLConnector.exec_queries()`, execute a list of queries 2415 and roll back when one fails. 2416 2417 Parameters 2418 ---------- 2419 session: sqlalchemy.orm.session.Session 2420 A SQLAlchemy session representing a transaction. 2421 2422 queries: Union[List[str], str] 2423 A query or list of queries to be executed. 2424 If a query fails, roll back the session. 2425 2426 with_results: bool, default False 2427 If `True`, return a list of result objects. 2428 2429 Returns 2430 ------- 2431 A `SuccessTuple` indicating the queries were successfully executed. 2432 If `with_results`, return the `SuccessTuple` and a list of results. 2433 """ 2434 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2435 if not isinstance(queries, list): 2436 queries = [queries] 2437 successes, msgs, results = [], [], [] 2438 for query in queries: 2439 if debug: 2440 dprint(query) 2441 query_text = sqlalchemy.text(query) 2442 fail_msg = "Failed to execute queries." 2443 try: 2444 result = session.execute(query_text) 2445 query_success = result is not None 2446 query_msg = "Success" if query_success else fail_msg 2447 except Exception as e: 2448 query_success = False 2449 query_msg = f"{fail_msg}\n{e}" 2450 result = None 2451 successes.append(query_success) 2452 msgs.append(query_msg) 2453 results.append(result) 2454 if not query_success: 2455 if debug: 2456 dprint("Rolling back session.") 2457 session.rollback() 2458 break 2459 success, msg = all(successes), '\n'.join(msgs) 2460 if with_results: 2461 return (success, msg), results 2462 return success, msg
Similar to `SQLConnector.exec_queries()`, execute a list of queries and roll back when one fails.
Parameters
- session (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
- queries (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
- with_results (bool, default False): If `True`, return a list of result objects.
- debug (bool, default False): Verbosity toggle.
Returns
- A `SuccessTuple` indicating the queries were successfully executed.
- If `with_results`, return the `SuccessTuple` and a list of results.
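A minimal sketch using the built-in `sql:local` connector; the table name `tmp_tbl` is hypothetical. Note that `session_execute()` rolls back on failure but does not commit, so commit explicitly on success:

```python
import meerschaum as mrsm
from meerschaum.utils.sql import session_execute

sqlalchemy_orm = mrsm.attempt_import('sqlalchemy.orm')
conn = mrsm.get_connector('sql:local')

### All queries run in one transaction; a failure rolls the session back.
with sqlalchemy_orm.Session(conn.engine) as session:
    success, msg = session_execute(
        session,
        [
            "CREATE TABLE tmp_tbl (id INTEGER)",
            "INSERT INTO tmp_tbl (id) VALUES (1)",
        ],
    )
    if success:
        session.commit()

print(success, msg)
```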
2465def get_reset_autoincrement_queries( 2466 table: str, 2467 column: str, 2468 connector: mrsm.connectors.SQLConnector, 2469 schema: Optional[str] = None, 2470 debug: bool = False, 2471) -> List[str]: 2472 """ 2473 Return a list of queries to reset a table's auto-increment counter to the next largest value. 2474 2475 Parameters 2476 ---------- 2477 table: str 2478 The name of the table on which the auto-incrementing column exists. 2479 2480 column: str 2481 The name of the auto-incrementing column. 2482 2483 connector: mrsm.connectors.SQLConnector 2484 The SQLConnector to the database on which the table exists. 2485 2486 schema: Optional[str], default None 2487 The schema of the table. Defaults to `connector.schema`. 2488 2489 Returns 2490 ------- 2491 A list of queries to be executed to reset the auto-incrementing column. 2492 """ 2493 if not table_exists(table, connector, schema=schema, debug=debug): 2494 return [] 2495 2496 schema = schema or connector.schema 2497 max_id_name = sql_item_name('max_id', connector.flavor) 2498 table_name = sql_item_name(table, connector.flavor, schema) 2499 table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema) 2500 column_name = sql_item_name(column, connector.flavor) 2501 max_id = connector.value( 2502 f""" 2503 SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name} 2504 FROM {table_name} 2505 """, 2506 debug=debug, 2507 ) 2508 if max_id is None: 2509 return [] 2510 2511 reset_queries = reset_autoincrement_queries.get( 2512 connector.flavor, 2513 reset_autoincrement_queries['default'] 2514 ) 2515 if not isinstance(reset_queries, list): 2516 reset_queries = [reset_queries] 2517 2518 return [ 2519 query.format( 2520 column=column, 2521 column_name=column_name, 2522 table=table, 2523 table_name=table_name, 2524 table_seq_name=table_seq_name, 2525 val=max_id, 2526 val_plus_1=(max_id + 1), 2527 ) 2528 for query in reset_queries 2529 ]
Return a list of queries to reset a table's auto-increment counter to the next largest value.
Parameters
- table (str): The name of the table on which the auto-incrementing column exists.
- column (str): The name of the auto-incrementing column.
- connector (mrsm.connectors.SQLConnector): The SQLConnector to the database on which the table exists.
- schema (Optional[str], default None): The schema of the table. Defaults to `connector.schema`.
- debug (bool, default False): Verbosity toggle.
Returns
- A list of queries to be executed to reset the auto-incrementing column.
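For example, assuming a table `items` with an auto-incrementing `id` column already exists on the connector's database:

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_reset_autoincrement_queries

conn = mrsm.get_connector('sql:local')

### Returns an empty list if the table does not exist
### or the flavor needs no reset queries.
for query in get_reset_autoincrement_queries('items', 'id', conn):
    conn.exec(query)
```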
2532def get_postgis_geo_columns_types( 2533 connectable: Union[ 2534 'mrsm.connectors.sql.SQLConnector', 2535 'sqlalchemy.orm.session.Session', 2536 'sqlalchemy.engine.base.Engine' 2537 ], 2538 table: str, 2539 schema: Optional[str] = 'public', 2540 debug: bool = False, 2541) -> Dict[str, str]: 2542 """ 2543 Return a dictionary mapping PostGIS geometry column names to geometry types. 2544 """ 2545 from meerschaum.utils.dtypes import get_geometry_type_srid 2546 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2547 default_type, default_srid = get_geometry_type_srid() 2548 default_type = default_type.upper() 2549 2550 clean(table) 2551 clean(str(schema)) 2552 schema = schema or 'public' 2553 truncated_schema_name = truncate_item_name(schema, flavor='postgis') 2554 truncated_table_name = truncate_item_name(table, flavor='postgis') 2555 query = sqlalchemy.text( 2556 "SELECT \"f_geometry_column\" AS \"column\", 'GEOMETRY' AS \"func\", \"type\", \"srid\"\n" 2557 "FROM \"geometry_columns\"\n" 2558 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2559 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2560 "UNION ALL\n" 2561 "SELECT \"f_geography_column\" AS \"column\", 'GEOGRAPHY' AS \"func\", \"type\", \"srid\"\n" 2562 "FROM \"geography_columns\"\n" 2563 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2564 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2565 ) 2566 debug_kwargs = {'debug': debug} if isinstance(connectable, mrsm.connectors.SQLConnector) else {} 2567 result_rows = [ 2568 row 2569 for row in connectable.execute(query, **debug_kwargs).fetchall() 2570 ] 2571 cols_type_tuples = { 2572 row[0]: (row[1], row[2], row[3]) 2573 for row in result_rows 2574 } 2575 2576 geometry_cols_types = { 2577 col: ( 2578 f"{func}({typ.upper()}, {srid})" 2579 if srid 2580 else ( 2581 func 2582 + ( 2583 f'({typ.upper()})' 2584 if typ.upper() not in ('GEOMETRY', 'GEOGRAPHY') 2585 else '' 2586 ) 2587 ) 2588 ) 2589 for col, (func, typ, srid) in cols_type_tuples.items() 2590 } 2591 return geometry_cols_types
Return a dictionary mapping PostGIS geometry column names to geometry types.
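A minimal sketch, assuming a PostGIS connector registered as `sql:postgis` and an existing table `parcels` with a geometry column:

```python
import meerschaum as mrsm
from meerschaum.utils.sql import get_postgis_geo_columns_types

conn = mrsm.get_connector('sql:postgis')
geo_cols_types = get_postgis_geo_columns_types(conn, 'parcels', schema='public')
### e.g. {'geom': 'GEOMETRY(POLYGON, 4326)'}
```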
2594def get_create_schema_if_not_exists_queries( 2595 schema: str, 2596 flavor: str, 2597) -> List[str]: 2598 """ 2599 Return the queries to create a schema if it does not yet exist. 2600 For databases which do not support schemas, an empty list will be returned. 2601 """ 2602 if not schema: 2603 return [] 2604 2605 if flavor in NO_SCHEMA_FLAVORS: 2606 return [] 2607 2608 if schema == DEFAULT_SCHEMA_FLAVORS.get(flavor, None): 2609 return [] 2610 2611 clean(schema) 2612 2613 if flavor == 'mssql': 2614 return [ 2615 ( 2616 f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema}')\n" 2617 "BEGIN\n" 2618 f" EXEC('CREATE SCHEMA {sql_item_name(schema, flavor)}');\n" 2619 "END;" 2620 ) 2621 ] 2622 2623 if flavor == 'oracle': 2624 return [] 2625 2626 return [ 2627 f"CREATE SCHEMA IF NOT EXISTS {sql_item_name(schema, flavor)};" 2628 ]
Return the queries to create a schema if it does not yet exist. For databases which do not support schemas, an empty list will be returned.
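For example:

```python
from meerschaum.utils.sql import get_create_schema_if_not_exists_queries

print(get_create_schema_if_not_exists_queries('analytics', 'postgresql'))
### ['CREATE SCHEMA IF NOT EXISTS "analytics";']

print(get_create_schema_if_not_exists_queries('analytics', 'sqlite'))
### [] (flavors without schema support return an empty list)
```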