meerschaum.utils.sql
Flavor-specific SQL tools.
1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# vim:fenc=utf-8 4 5""" 6Flavor-specific SQL tools. 7""" 8 9from __future__ import annotations 10from datetime import datetime, timezone, timedelta 11import meerschaum as mrsm 12from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple 13### Preserve legacy imports. 14from meerschaum.utils.dtypes.sql import ( 15 DB_TO_PD_DTYPES, 16 PD_TO_DB_DTYPES_FLAVORS, 17 get_pd_type_from_db_type as get_pd_type, 18 get_db_type_from_pd_type as get_db_type, 19 TIMEZONE_NAIVE_FLAVORS, 20) 21from meerschaum.utils.warnings import warn 22from meerschaum.utils.debug import dprint 23 24test_queries = { 25 'default' : 'SELECT 1', 26 'oracle' : 'SELECT 1 FROM DUAL', 27 'informix' : 'SELECT COUNT(*) FROM systables', 28 'hsqldb' : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS', 29} 30### `table_name` is the escaped name of the table. 31### `table` is the unescaped name of the table. 32exists_queries = { 33 'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0", 34} 35version_queries = { 36 'default': "SELECT VERSION() AS {version_name}", 37 'sqlite': "SELECT SQLITE_VERSION() AS {version_name}", 38 'mssql': "SELECT @@version", 39 'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1", 40} 41SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'} 42DROP_IF_EXISTS_FLAVORS = { 43 'timescaledb', 'postgresql', 'citus', 'mssql', 'mysql', 'mariadb', 'sqlite', 44} 45SKIP_AUTO_INCREMENT_FLAVORS = {'citus', 'duckdb'} 46COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'postgresql', 'citus'} 47update_queries = { 48 'default': """ 49 UPDATE {target_table_name} AS f 50 {sets_subquery_none} 51 FROM {target_table_name} AS t 52 INNER JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 53 ON {and_subquery_t} 54 WHERE 55 {and_subquery_f} 56 AND {date_bounds_subquery} 57 """, 58 'timescaledb-upsert': """ 59 INSERT INTO {target_table_name} ({patch_cols_str}) 60 SELECT {patch_cols_str} 61 FROM 
{patch_table_name} 62 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 63 """, 64 'postgresql-upsert': """ 65 INSERT INTO {target_table_name} ({patch_cols_str}) 66 SELECT {patch_cols_str} 67 FROM {patch_table_name} 68 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 69 """, 70 'citus-upsert': """ 71 INSERT INTO {target_table_name} ({patch_cols_str}) 72 SELECT {patch_cols_str} 73 FROM {patch_table_name} 74 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 75 """, 76 'cockroachdb-upsert': """ 77 INSERT INTO {target_table_name} ({patch_cols_str}) 78 SELECT {patch_cols_str} 79 FROM {patch_table_name} 80 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 81 """, 82 'mysql': """ 83 UPDATE {target_table_name} AS f 84 JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 85 ON {and_subquery_f} 86 {sets_subquery_f} 87 WHERE {date_bounds_subquery} 88 """, 89 'mysql-upsert': """ 90 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 91 SELECT {patch_cols_str} 92 FROM {patch_table_name} 93 {on_duplicate_key_update} 94 {cols_equal_values} 95 """, 96 'mariadb': """ 97 UPDATE {target_table_name} AS f 98 JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 99 ON {and_subquery_f} 100 {sets_subquery_f} 101 WHERE {date_bounds_subquery} 102 """, 103 'mariadb-upsert': """ 104 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 105 SELECT {patch_cols_str} 106 FROM {patch_table_name} 107 {on_duplicate_key_update} 108 {cols_equal_values} 109 """, 110 'mssql': """ 111 MERGE {target_table_name} f 112 USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p 113 ON {and_subquery_f} 114 AND {date_bounds_subquery} 115 WHEN MATCHED THEN 116 UPDATE 117 {sets_subquery_none}; 118 """, 119 'mssql-upsert': """ 120 MERGE {target_table_name} f 121 USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p 122 
ON {and_subquery_f} 123 AND {date_bounds_subquery} 124 {when_matched_update_sets_subquery_none} 125 WHEN NOT MATCHED THEN 126 INSERT ({patch_cols_str}) 127 VALUES ({patch_cols_prefixed_str}); 128 """, 129 'oracle': """ 130 MERGE INTO {target_table_name} f 131 USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p 132 ON ( 133 {and_subquery_f} 134 AND {date_bounds_subquery} 135 ) 136 WHEN MATCHED THEN 137 UPDATE 138 {sets_subquery_none} 139 """, 140 'sqlite-upsert': """ 141 INSERT INTO {target_table_name} ({patch_cols_str}) 142 SELECT {patch_cols_str} 143 FROM {patch_table_name} 144 WHERE true 145 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 146 """, 147 'sqlite_delete_insert': [ 148 """ 149 DELETE FROM {target_table_name} AS f 150 WHERE ROWID IN ( 151 SELECT t.ROWID 152 FROM {target_table_name} AS t 153 INNER JOIN (SELECT DISTINCT * FROM {patch_table_name}) AS p 154 ON {and_subquery_t} 155 ); 156 """, 157 """ 158 INSERT INTO {target_table_name} AS f 159 SELECT DISTINCT {patch_cols_str} FROM {patch_table_name} AS p 160 """, 161 ], 162} 163columns_types_queries = { 164 'default': """ 165 SELECT 166 table_catalog AS database, 167 table_schema AS schema, 168 table_name AS table, 169 column_name AS column, 170 data_type AS type 171 FROM information_schema.columns 172 WHERE table_name IN ('{table}', '{table_trunc}') 173 """, 174 'sqlite': """ 175 SELECT 176 '' "database", 177 '' "schema", 178 m.name "table", 179 p.name "column", 180 p.type "type" 181 FROM sqlite_master m 182 LEFT OUTER JOIN pragma_table_info(m.name) p 183 ON m.name <> p.name 184 WHERE m.type = 'table' 185 AND m.name IN ('{table}', '{table_trunc}') 186 """, 187 'mssql': """ 188 SELECT 189 TABLE_CATALOG AS [database], 190 TABLE_SCHEMA AS [schema], 191 TABLE_NAME AS [table], 192 COLUMN_NAME AS [column], 193 DATA_TYPE AS [type] 194 FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS 195 WHERE TABLE_NAME IN ( 196 '{table}', 197 '{table_trunc}' 198 ) 199 200 """, 201 
'mysql': """ 202 SELECT 203 TABLE_SCHEMA `database`, 204 TABLE_SCHEMA `schema`, 205 TABLE_NAME `table`, 206 COLUMN_NAME `column`, 207 DATA_TYPE `type` 208 FROM INFORMATION_SCHEMA.COLUMNS 209 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 210 """, 211 'mariadb': """ 212 SELECT 213 TABLE_SCHEMA `database`, 214 TABLE_SCHEMA `schema`, 215 TABLE_NAME `table`, 216 COLUMN_NAME `column`, 217 DATA_TYPE `type` 218 FROM INFORMATION_SCHEMA.COLUMNS 219 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 220 """, 221 'oracle': """ 222 SELECT 223 NULL AS "database", 224 NULL AS "schema", 225 TABLE_NAME AS "table", 226 COLUMN_NAME AS "column", 227 DATA_TYPE AS "type" 228 FROM all_tab_columns 229 WHERE TABLE_NAME IN ( 230 '{table}', 231 '{table_trunc}', 232 '{table_lower}', 233 '{table_lower_trunc}', 234 '{table_upper}', 235 '{table_upper_trunc}' 236 ) 237 """, 238} 239hypertable_queries = { 240 'timescaledb': 'SELECT hypertable_size(\'{table_name}\')', 241 'citus': 'SELECT citus_table_size(\'{table_name}\')', 242} 243columns_indices_queries = { 244 'default': """ 245 SELECT 246 current_database() AS "database", 247 n.nspname AS "schema", 248 t.relname AS "table", 249 c.column_name AS "column", 250 i.relname AS "index", 251 CASE WHEN con.contype = 'p' THEN 'PRIMARY KEY' ELSE 'INDEX' END AS "index_type" 252 FROM pg_class t 253 INNER JOIN pg_index AS ix 254 ON t.oid = ix.indrelid 255 INNER JOIN pg_class AS i 256 ON i.oid = ix.indexrelid 257 INNER JOIN pg_namespace AS n 258 ON n.oid = t.relnamespace 259 INNER JOIN pg_attribute AS a 260 ON a.attnum = ANY(ix.indkey) 261 AND a.attrelid = t.oid 262 INNER JOIN information_schema.columns AS c 263 ON c.column_name = a.attname 264 AND c.table_name = t.relname 265 AND c.table_schema = n.nspname 266 LEFT JOIN pg_constraint AS con 267 ON con.conindid = i.oid 268 AND con.contype = 'p' 269 WHERE 270 t.relname IN ('{table}', '{table_trunc}') 271 AND n.nspname = '{schema}' 272 """, 273 'sqlite': """ 274 WITH indexed_columns AS ( 275 SELECT 276 
'{table}' AS table_name, 277 pi.name AS column_name, 278 i.name AS index_name, 279 'INDEX' AS index_type 280 FROM 281 sqlite_master AS i, 282 pragma_index_info(i.name) AS pi 283 WHERE 284 i.type = 'index' 285 AND i.tbl_name = '{table}' 286 ), 287 primary_key_columns AS ( 288 SELECT 289 '{table}' AS table_name, 290 ti.name AS column_name, 291 'PRIMARY_KEY' AS index_name, 292 'PRIMARY KEY' AS index_type 293 FROM 294 pragma_table_info('{table}') AS ti 295 WHERE 296 ti.pk > 0 297 ) 298 SELECT 299 NULL AS "database", 300 NULL AS "schema", 301 "table_name" AS "table", 302 "column_name" AS "column", 303 "index_name" AS "index", 304 "index_type" 305 FROM indexed_columns 306 UNION ALL 307 SELECT 308 NULL AS "database", 309 NULL AS "schema", 310 table_name AS "table", 311 column_name AS "column", 312 index_name AS "index", 313 index_type 314 FROM primary_key_columns 315 """, 316 'mssql': """ 317 SELECT 318 NULL AS [database], 319 s.name AS [schema], 320 t.name AS [table], 321 c.name AS [column], 322 i.name AS [index], 323 CASE 324 WHEN kc.type = 'PK' THEN 'PRIMARY KEY' 325 ELSE 'INDEX' 326 END AS [index_type] 327 FROM 328 sys.schemas s 329 INNER JOIN sys.tables t 330 ON s.schema_id = t.schema_id 331 INNER JOIN sys.indexes i 332 ON t.object_id = i.object_id 333 INNER JOIN sys.index_columns ic 334 ON i.object_id = ic.object_id 335 AND i.index_id = ic.index_id 336 INNER JOIN sys.columns c 337 ON ic.object_id = c.object_id 338 AND ic.column_id = c.column_id 339 LEFT JOIN sys.key_constraints kc 340 ON kc.parent_object_id = i.object_id 341 AND kc.type = 'PK' 342 AND kc.name = i.name 343 WHERE 344 t.name IN ('{table}', '{table_trunc}') 345 AND s.name = '{schema}' 346 AND i.type IN (1, 2) -- 1 = CLUSTERED, 2 = NONCLUSTERED 347 """, 348 'oracle': """ 349 SELECT 350 NULL AS "database", 351 ic.table_owner AS "schema", 352 ic.table_name AS "table", 353 ic.column_name AS "column", 354 i.index_name AS "index", 355 CASE 356 WHEN c.constraint_type = 'P' THEN 'PRIMARY KEY' 357 WHEN 
i.uniqueness = 'UNIQUE' THEN 'UNIQUE INDEX' 358 ELSE 'INDEX' 359 END AS index_type 360 FROM 361 all_ind_columns ic 362 INNER JOIN all_indexes i 363 ON ic.index_name = i.index_name 364 AND ic.table_owner = i.owner 365 LEFT JOIN all_constraints c 366 ON i.index_name = c.constraint_name 367 AND i.table_owner = c.owner 368 AND c.constraint_type = 'P' 369 WHERE ic.table_name IN ( 370 '{table}', 371 '{table_trunc}', 372 '{table_upper}', 373 '{table_upper_trunc}' 374 ) 375 """, 376 'mysql': """ 377 SELECT 378 TABLE_SCHEMA AS `database`, 379 TABLE_SCHEMA AS `schema`, 380 TABLE_NAME AS `table`, 381 COLUMN_NAME AS `column`, 382 INDEX_NAME AS `index`, 383 CASE 384 WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY' 385 ELSE 'INDEX' 386 END AS `index_type` 387 FROM 388 information_schema.STATISTICS 389 WHERE 390 TABLE_NAME IN ('{table}', '{table_trunc}') 391 """, 392 'mariadb': """ 393 SELECT 394 TABLE_SCHEMA AS `database`, 395 TABLE_SCHEMA AS `schema`, 396 TABLE_NAME AS `table`, 397 COLUMN_NAME AS `column`, 398 INDEX_NAME AS `index`, 399 CASE 400 WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY' 401 ELSE 'INDEX' 402 END AS `index_type` 403 FROM 404 information_schema.STATISTICS 405 WHERE 406 TABLE_NAME IN ('{table}', '{table_trunc}') 407 """, 408} 409reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = { 410 'default': """ 411 SELECT SETVAL(pg_get_serial_sequence('{table}', '{column}'), {val}) 412 FROM {table_name} 413 """, 414 'mssql': """ 415 DBCC CHECKIDENT ('{table}', RESEED, {val}) 416 """, 417 'mysql': """ 418 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 419 """, 420 'mariadb': """ 421 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 422 """, 423 'sqlite': """ 424 UPDATE sqlite_sequence 425 SET seq = {val} 426 WHERE name = '{table}' 427 """, 428 'oracle': [ 429 """ 430 DECLARE 431 max_id NUMBER := {val}; 432 current_val NUMBER; 433 BEGIN 434 SELECT {table_seq_name}.NEXTVAL INTO current_val FROM dual; 435 436 WHILE current_val < max_id LOOP 437 SELECT {table_seq_name}.NEXTVAL INTO 
current_val FROM dual; 438 END LOOP; 439 END; 440 """, 441 ], 442} 443table_wrappers = { 444 'default' : ('"', '"'), 445 'timescaledb': ('"', '"'), 446 'citus' : ('"', '"'), 447 'duckdb' : ('"', '"'), 448 'postgresql' : ('"', '"'), 449 'sqlite' : ('"', '"'), 450 'mysql' : ('`', '`'), 451 'mariadb' : ('`', '`'), 452 'mssql' : ('[', ']'), 453 'cockroachdb': ('"', '"'), 454 'oracle' : ('"', '"'), 455} 456max_name_lens = { 457 'default' : 64, 458 'mssql' : 128, 459 'oracle' : 30, 460 'postgresql' : 64, 461 'timescaledb': 64, 462 'citus' : 64, 463 'cockroachdb': 64, 464 'sqlite' : 1024, ### Probably more, but 1024 seems more than reasonable. 465 'mysql' : 64, 466 'mariadb' : 64, 467} 468json_flavors = {'postgresql', 'timescaledb', 'citus', 'cockroachdb'} 469NO_SCHEMA_FLAVORS = {'oracle', 'sqlite', 'mysql', 'mariadb', 'duckdb'} 470DEFAULT_SCHEMA_FLAVORS = { 471 'postgresql': 'public', 472 'timescaledb': 'public', 473 'citus': 'public', 474 'cockroachdb': 'public', 475 'mysql': 'mysql', 476 'mariadb': 'mysql', 477 'mssql': 'dbo', 478} 479OMIT_NULLSFIRST_FLAVORS = {'mariadb', 'mysql', 'mssql'} 480 481SINGLE_ALTER_TABLE_FLAVORS = {'duckdb', 'sqlite', 'mssql', 'oracle'} 482NO_CTE_FLAVORS = {'mysql', 'mariadb'} 483NO_SELECT_INTO_FLAVORS = {'sqlite', 'oracle', 'mysql', 'mariadb', 'duckdb'} 484 485 486def clean(substring: str) -> str: 487 """ 488 Ensure a substring is clean enough to be inserted into a SQL query. 489 Raises an exception when banned words are used. 490 """ 491 from meerschaum.utils.warnings import error 492 banned_symbols = [';', '--', 'drop ',] 493 for symbol in banned_symbols: 494 if symbol in str(substring).lower(): 495 error(f"Invalid string: '{substring}'") 496 497 498def dateadd_str( 499 flavor: str = 'postgresql', 500 datepart: str = 'day', 501 number: Union[int, float] = 0, 502 begin: Union[str, datetime, int] = 'now' 503) -> str: 504 """ 505 Generate a `DATEADD` clause depending on database flavor. 
506 507 Parameters 508 ---------- 509 flavor: str, default `'postgresql'` 510 SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. 511 512 Currently supported flavors: 513 514 - `'postgresql'` 515 - `'timescaledb'` 516 - `'citus'` 517 - `'cockroachdb'` 518 - `'duckdb'` 519 - `'mssql'` 520 - `'mysql'` 521 - `'mariadb'` 522 - `'sqlite'` 523 - `'oracle'` 524 525 datepart: str, default `'day'` 526 Which part of the date to modify. Supported values: 527 528 - `'year'` 529 - `'month'` 530 - `'day'` 531 - `'hour'` 532 - `'minute'` 533 - `'second'` 534 535 number: Union[int, float], default `0` 536 How many units to add to the date part. 537 538 begin: Union[str, datetime], default `'now'` 539 Base datetime to which to add dateparts. 540 541 Returns 542 ------- 543 The appropriate `DATEADD` string for the corresponding database flavor. 544 545 Examples 546 -------- 547 >>> dateadd_str( 548 ... flavor = 'mssql', 549 ... begin = datetime(2022, 1, 1, 0, 0), 550 ... number = 1, 551 ... ) 552 "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))" 553 >>> dateadd_str( 554 ... flavor = 'postgresql', 555 ... begin = datetime(2022, 1, 1, 0, 0), 556 ... number = 1, 557 ... ) 558 "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'" 559 560 """ 561 from meerschaum.utils.packages import attempt_import 562 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type 563 dateutil_parser = attempt_import('dateutil.parser') 564 if 'int' in str(type(begin)).lower(): 565 return str(begin) 566 if not begin: 567 return '' 568 569 _original_begin = begin 570 begin_time = None 571 ### Sanity check: make sure `begin` is a valid datetime before we inject anything. 572 if not isinstance(begin, datetime): 573 try: 574 begin_time = dateutil_parser.parse(begin) 575 except Exception: 576 begin_time = None 577 else: 578 begin_time = begin 579 580 ### Unable to parse into a datetime. 581 if begin_time is None: 582 ### Throw an error if banned symbols are included in the `begin` string. 
583 clean(str(begin)) 584 ### If begin is a valid datetime, wrap it in quotes. 585 else: 586 if isinstance(begin, datetime) and begin.tzinfo is not None: 587 begin = begin.astimezone(timezone.utc) 588 begin = ( 589 f"'{begin.replace(tzinfo=None)}'" 590 if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS 591 else f"'{begin}'" 592 ) 593 594 dt_is_utc = begin_time.tzinfo is not None if begin_time is not None else '+' in str(begin) 595 db_type = get_db_type_from_pd_type( 596 ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'), 597 flavor=flavor, 598 ) 599 600 da = "" 601 if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'): 602 begin = ( 603 f"CAST({begin} AS {db_type})" if begin != 'now' 604 else "CAST(NOW() AT TIME ZONE 'utc' AS {db_type})" 605 ) 606 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 607 608 elif flavor == 'duckdb': 609 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()' 610 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 611 612 elif flavor in ('mssql',): 613 if begin_time and begin_time.microsecond != 0 and not dt_is_utc: 614 begin = begin[:-4] + "'" 615 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()' 616 da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin 617 618 elif flavor in ('mysql', 'mariadb'): 619 begin = f"CAST({begin} AS DATETIME(6))" if begin != 'now' else 'UTC_TIMESTAMP(6)' 620 da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin) 621 622 elif flavor == 'sqlite': 623 da = f"datetime({begin}, '{number} {datepart}')" 624 625 elif flavor == 'oracle': 626 if begin == 'now': 627 begin = str( 628 datetime.now(timezone.utc).replace(tzinfo=None).strftime('%Y:%m:%d %M:%S.%f') 629 ) 630 elif begin_time: 631 begin = str(begin_time.strftime('%Y-%m-%d %H:%M:%S.%f')) 632 dt_format = 'YYYY-MM-DD HH24:MI:SS.FF' 633 _begin = f"'{begin}'" if begin_time else begin 634 da = ( 
635 (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin) 636 + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "") 637 ) 638 return da 639 640 641def test_connection( 642 self, 643 **kw: Any 644) -> Union[bool, None]: 645 """ 646 Test if a successful connection to the database may be made. 647 648 Parameters 649 ---------- 650 **kw: 651 The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`. 652 653 Returns 654 ------- 655 `True` if a connection is made, otherwise `False` or `None` in case of failure. 656 657 """ 658 import warnings 659 from meerschaum.connectors.poll import retry_connect 660 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 661 _default_kw.update(kw) 662 with warnings.catch_warnings(): 663 warnings.filterwarnings('ignore', 'Could not') 664 try: 665 return retry_connect(**_default_kw) 666 except Exception as e: 667 return False 668 669 670def get_distinct_col_count( 671 col: str, 672 query: str, 673 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 674 debug: bool = False 675) -> Optional[int]: 676 """ 677 Returns the number of distinct items in a column of a SQL query. 678 679 Parameters 680 ---------- 681 col: str: 682 The column in the query to count. 683 684 query: str: 685 The SQL query to count from. 686 687 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 688 The SQLConnector to execute the query. 689 690 debug: bool, default False: 691 Verbosity toggle. 692 693 Returns 694 ------- 695 An `int` of the number of columns in the query or `None` if the query fails. 
696 697 """ 698 if connector is None: 699 connector = mrsm.get_connector('sql') 700 701 _col_name = sql_item_name(col, connector.flavor, None) 702 703 _meta_query = ( 704 f""" 705 WITH src AS ( {query} ), 706 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 707 SELECT COUNT(*) FROM dist""" 708 ) if connector.flavor not in ('mysql', 'mariadb') else ( 709 f""" 710 SELECT COUNT(*) 711 FROM ( 712 SELECT DISTINCT {_col_name} 713 FROM ({query}) AS src 714 ) AS dist""" 715 ) 716 717 result = connector.value(_meta_query, debug=debug) 718 try: 719 return int(result) 720 except Exception as e: 721 return None 722 723 724def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 725 """ 726 Parse SQL items depending on the flavor. 727 728 Parameters 729 ---------- 730 item: str : 731 The database item (table, view, etc.) in need of quotes. 732 733 flavor: str : 734 The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.). 735 736 Returns 737 ------- 738 A `str` which contains the input `item` wrapped in the corresponding escape characters. 739 740 Examples 741 -------- 742 >>> sql_item_name('table', 'sqlite') 743 '"table"' 744 >>> sql_item_name('table', 'mssql') 745 "[table]" 746 >>> sql_item_name('table', 'postgresql', schema='abc') 747 '"abc"."table"' 748 749 """ 750 truncated_item = truncate_item_name(str(item), flavor) 751 if flavor == 'oracle': 752 truncated_item = pg_capital(truncated_item) 753 ### NOTE: System-reserved words must be quoted. 754 if truncated_item.lower() in ( 755 'float', 'varchar', 'nvarchar', 'clob', 756 'boolean', 'integer', 'table', 757 ): 758 wrappers = ('"', '"') 759 else: 760 wrappers = ('', '') 761 else: 762 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 763 764 ### NOTE: SQLite does not support schemas. 
765 if flavor == 'sqlite': 766 schema = None 767 768 schema_prefix = ( 769 (wrappers[0] + schema + wrappers[1] + '.') 770 if schema is not None 771 else '' 772 ) 773 774 return schema_prefix + wrappers[0] + truncated_item + wrappers[1] 775 776 777def pg_capital(s: str) -> str: 778 """ 779 If string contains a capital letter, wrap it in double quotes. 780 781 Parameters 782 ---------- 783 s: str : 784 The string to be escaped. 785 786 Returns 787 ------- 788 The input string wrapped in quotes only if it needs them. 789 790 Examples 791 -------- 792 >>> pg_capital("My Table") 793 '"My Table"' 794 >>> pg_capital('my_table') 795 'my_table' 796 797 """ 798 if '"' in s: 799 return s 800 needs_quotes = s.startswith('_') 801 for c in str(s): 802 if ord(c) < ord('a') or ord(c) > ord('z'): 803 if not c.isdigit() and c != '_': 804 needs_quotes = True 805 break 806 if needs_quotes: 807 return '"' + s + '"' 808 return s 809 810 811def oracle_capital(s: str) -> str: 812 """ 813 Capitalize the string of an item on an Oracle database. 814 """ 815 return s 816 817 818def truncate_item_name(item: str, flavor: str) -> str: 819 """ 820 Truncate item names to stay within the database flavor's character limit. 821 822 Parameters 823 ---------- 824 item: str 825 The database item being referenced. This string is the "canonical" name internally. 826 827 flavor: str 828 The flavor of the database on which `item` resides. 829 830 Returns 831 ------- 832 The truncated string. 833 """ 834 from meerschaum.utils.misc import truncate_string_sections 835 return truncate_string_sections( 836 item, max_len=max_name_lens.get(flavor, max_name_lens['default']) 837 ) 838 839 840def build_where( 841 params: Dict[str, Any], 842 connector: Optional[meerschaum.connectors.sql.SQLConnector] = None, 843 with_where: bool = True, 844) -> str: 845 """ 846 Build the `WHERE` clause based on the input criteria. 
847 848 Parameters 849 ---------- 850 params: Dict[str, Any]: 851 The keywords dictionary to convert into a WHERE clause. 852 If a value is a string which begins with an underscore, negate that value 853 (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). 854 A value of `_None` will be interpreted as `IS NOT NULL`. 855 856 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 857 The Meerschaum SQLConnector that will be executing the query. 858 The connector is used to extract the SQL dialect. 859 860 with_where: bool, default True: 861 If `True`, include the leading `'WHERE'` string. 862 863 Returns 864 ------- 865 A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor. 866 867 Examples 868 -------- 869 ``` 870 >>> print(build_where({'foo': [1, 2, 3]})) 871 872 WHERE 873 "foo" IN ('1', '2', '3') 874 ``` 875 """ 876 import json 877 from meerschaum.config.static import STATIC_CONFIG 878 from meerschaum.utils.warnings import warn 879 from meerschaum.utils.dtypes import value_is_null, none_if_null 880 negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix'] 881 try: 882 params_json = json.dumps(params) 883 except Exception as e: 884 params_json = str(params) 885 bad_words = ['drop ', '--', ';'] 886 for word in bad_words: 887 if word in params_json.lower(): 888 warn(f"Aborting build_where() due to possible SQL injection.") 889 return '' 890 891 if connector is None: 892 from meerschaum import get_connector 893 connector = get_connector('sql') 894 where = "" 895 leading_and = "\n AND " 896 for key, value in params.items(): 897 _key = sql_item_name(key, connector.flavor, None) 898 ### search across a list (i.e. 
IN syntax) 899 if isinstance(value, Iterable) and not isinstance(value, (dict, str)): 900 includes = [ 901 none_if_null(item) 902 for item in value 903 if not str(item).startswith(negation_prefix) 904 ] 905 null_includes = [item for item in includes if item is None] 906 not_null_includes = [item for item in includes if item is not None] 907 excludes = [ 908 none_if_null(str(item)[len(negation_prefix):]) 909 for item in value 910 if str(item).startswith(negation_prefix) 911 ] 912 null_excludes = [item for item in excludes if item is None] 913 not_null_excludes = [item for item in excludes if item is not None] 914 915 if includes: 916 where += f"{leading_and}(" 917 if not_null_includes: 918 where += f"{_key} IN (" 919 for item in not_null_includes: 920 quoted_item = str(item).replace("'", "''") 921 where += f"'{quoted_item}', " 922 where = where[:-2] + ")" 923 if null_includes: 924 where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL" 925 if includes: 926 where += ")" 927 928 if excludes: 929 where += f"{leading_and}(" 930 if not_null_excludes: 931 where += f"{_key} NOT IN (" 932 for item in not_null_excludes: 933 quoted_item = str(item).replace("'", "''") 934 where += f"'{quoted_item}', " 935 where = where[:-2] + ")" 936 if null_excludes: 937 where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL" 938 if excludes: 939 where += ")" 940 941 continue 942 943 ### search a dictionary 944 elif isinstance(value, dict): 945 import json 946 where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'") 947 continue 948 949 eq_sign = '=' 950 is_null = 'IS NULL' 951 if value_is_null(str(value).lstrip(negation_prefix)): 952 value = ( 953 (negation_prefix + 'None') 954 if str(value).startswith(negation_prefix) 955 else None 956 ) 957 if str(value).startswith(negation_prefix): 958 value = str(value)[len(negation_prefix):] 959 eq_sign = '!=' 960 if value_is_null(value): 961 value = None 962 is_null = 'IS NOT NULL' 963 quoted_value = 
str(value).replace("'", "''") 964 where += ( 965 f"{leading_and}{_key} " 966 + (is_null if value is None else f"{eq_sign} '{quoted_value}'") 967 ) 968 969 if len(where) > 1: 970 where = ("\nWHERE\n " if with_where else '') + where[len(leading_and):] 971 return where 972 973 974def table_exists( 975 table: str, 976 connector: mrsm.connectors.sql.SQLConnector, 977 schema: Optional[str] = None, 978 debug: bool = False, 979) -> bool: 980 """Check if a table exists. 981 982 Parameters 983 ---------- 984 table: str: 985 The name of the table in question. 986 987 connector: mrsm.connectors.sql.SQLConnector 988 The connector to the database which holds the table. 989 990 schema: Optional[str], default None 991 Optionally specify the table schema. 992 Defaults to `connector.schema`. 993 994 debug: bool, default False : 995 Verbosity toggle. 996 997 Returns 998 ------- 999 A `bool` indicating whether or not the table exists on the database. 1000 """ 1001 sqlalchemy = mrsm.attempt_import('sqlalchemy') 1002 schema = schema or connector.schema 1003 insp = sqlalchemy.inspect(connector.engine) 1004 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1005 return insp.has_table(truncated_table_name, schema=schema) 1006 1007 1008def get_sqlalchemy_table( 1009 table: str, 1010 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1011 schema: Optional[str] = None, 1012 refresh: bool = False, 1013 debug: bool = False, 1014) -> Union['sqlalchemy.Table', None]: 1015 """ 1016 Construct a SQLAlchemy table from its name. 1017 1018 Parameters 1019 ---------- 1020 table: str 1021 The name of the table on the database. Does not need to be escaped. 1022 1023 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1024 The connector to the database which holds the table. 1025 1026 schema: Optional[str], default None 1027 Specify on which schema the table resides. 1028 Defaults to the schema set in `connector`. 
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to data types.
    """
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None

    ### Query against every case variant so case-folding flavors (e.g. Oracle) still match.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)

    ### MSSQL temporary tables (names beginning with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )
    )

    cols = ['database', 'schema', 'table', 'column', 'type']
    result_cols_ix = dict(enumerate(cols))

    ### Only a `SQLConnector` accepts the `debug` keyword; sessions/engines do not.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        result_rows = (
            list(connectable.execute(cols_types_query, **debug_kwargs).fetchall())
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle upper-cases unquoted identifiers; fold all-caps alphabetic names back down.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper()
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to a list of index documents
    (each with `'name'` and `'type'` keys).
    """
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None

    ### Query against every case variant so case-folding flavors (e.g. Oracle) still match.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)

    ### MSSQL temporary tables (names beginning with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )
    )

    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    result_cols_ix = dict(enumerate(cols))

    ### Only a `SQLConnector` accepts the `debug` keyword; sessions/engines do not.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        result_rows = (
            list(connectable.execute(cols_indices_query, **debug_kwargs).fetchall())
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        cols_indices = defaultdict(list)
        for doc in cols_types_docs_filtered:
            ### Oracle upper-cases unquoted identifiers; fold all-caps alphabetic names back down.
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            cols_indices[col].append(
                {
                    'name': doc.get('index', None),
                    'type': doc.get('index_type', None),
                }
            )

        return dict(cols_indices)
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: List[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")
    ### Older SQLite lacks `UPDATE ... FROM`; fall back to DELETE + INSERT.
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and connectable.db_version < '3.33.0'
    ):
        flavor = 'sqlite_delete_insert'
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = update_queries.get(
        flavor_key,
        update_queries['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Split shared columns into join keys and value columns.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    coalesce_join_cols_str = ', '.join(
        [
            'COALESCE('
            + sql_item_name(c_name, flavor)
            + ', '
            + get_null_replacement(c_type, flavor)
            + ')'
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        """Build the `SET` clause, casting each value column (except on SQLite)."""
        if not value_cols:
            return ''
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + ('CAST(' if flavor != 'sqlite' else '')
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + (' AS ' if flavor != 'sqlite' else '')
                + (c_type.replace('_', ' ') if flavor != 'sqlite' else '')
                + (')' if flavor != 'sqlite' else '')
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        """Build NULL-safe equality conditions over the join columns."""
        return '\nAND\n'.join([
            (
                "COALESCE("
                + l_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
                + ' = '
                + "COALESCE("
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
            ) for c_name, c_type in join_cols_types
        ])

    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    date_bounds_subquery = (
        f"""
        f.{dt_col_name} >= (SELECT MIN({dt_col_name}) FROM {patch_table_name})
        AND f.{dt_col_name} <= (SELECT MAX({dt_col_name}) FROM {patch_table_name})
        """
        if datetime_col
        else "1 = 1"
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "WHEN MATCHED THEN"
        f" UPDATE {sets_subquery('', 'p.')}"
    )

    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    return [
        base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
        )
        for base_query in base_queries
    ]
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The SQL type for which a NULL stand-in value is needed.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A SQL literal (or CAST expression) which may stand in place of NULL for this type.
    A quoted `'-987654321'` string is returned when no more specific value applies.
    """
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'):
        return '-987654321'
    if 'bool' in typ.lower() or typ.lower() == 'bit':
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        ### MySQL/MariaDB tolerate the magic integer; other flavors require 0/1 for bools.
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'
    if 'time' in typ.lower() or 'date' in typ.lower():
        return dateadd_str(flavor=flavor, begin='1900-01-01')
    if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',):
        return '-987654321.0'
    if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"
    if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"
    ### Default: a quoted magic string (NVARCHAR literal on Oracle).
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"


def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.

    Parameters
    ----------
    conn: SQLConnector
        The connector to the database from which to read the version.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    The version string returned by the flavor-specific version query, or `None`.
    """
    version_name = sql_item_name('version', conn.flavor, None)
    version_query = version_queries.get(
        conn.flavor,
        version_queries['default']
    ).format(version_name=version_name)
    return conn.value(version_query, debug=debug)
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    old_table_name = sql_item_name(old_table, flavor, schema)
    new_table_name = sql_item_name(new_table, flavor, None)
    tmp_table = '_tmp_rename_' + new_table
    tmp_table_name = sql_item_name(tmp_table, flavor, schema)
    if flavor == 'mssql':
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""
    if flavor == 'duckdb':
        ### DuckDB: copy through a temporary table, then drop the intermediates.
        return (
            get_create_table_queries(
                f"SELECT * FROM {old_table_name}",
                tmp_table,
                'duckdb',
                schema,
            ) + get_create_table_queries(
                f"SELECT * FROM {tmp_table_name}",
                new_table,
                'duckdb',
                schema,
            ) + [
                f"DROP TABLE {if_exists_str} {tmp_table_name}",
                f"DROP TABLE {if_exists_str} {old_table_name}",
            ]
        )

    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]


def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    ### Deprecated thin wrapper: return only the first of the generated queries.
    return get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )[0]


def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    if not isinstance(query_or_dtypes, (str, dict)):
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    ### Dispatch on input kind: a SELECT query builds via CTE, a dtypes dict builds columns.
    method = (
        _get_create_table_query_from_cte
        if isinstance(query_or_dtypes, str)
        else _get_create_table_query_from_dtypes
    )
    return method(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS),
        datetime_column=datetime_column,
    )
def _get_create_table_query_from_dtypes(
    dtypes: Dict[str, str],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Create a new table from a `dtypes` dictionary.

    Builds a flavor-specific `CREATE TABLE` statement column-by-column,
    converting each Pandas dtype via `get_db_type_from_pd_type()`.
    Returns a single-element list containing the query.
    """
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, AUTO_INCREMENT_COLUMN_FLAVORS
    if not dtypes and not primary_key:
        raise ValueError(f"Expecting columns for table '{new_table}'.")

    ### These flavors never get an auto-increment clause.
    if flavor in SKIP_AUTO_INCREMENT_FLAVORS:
        autoincrement = False

    ### Put the primary key first; default its dtype to 'int' when absent from `dtypes`.
    cols_types = (
        [(primary_key, get_db_type_from_pd_type(dtypes.get(primary_key, 'int'), flavor=flavor))]
        if primary_key
        else []
    ) + [
        (col, get_db_type_from_pd_type(typ, flavor=flavor))
        for col, typ in dtypes.items()
        if col != primary_key
    ]

    table_name = sql_item_name(new_table, schema=schema, flavor=flavor)
    primary_key_name = sql_item_name(primary_key, flavor) if primary_key else None
    datetime_column_name = sql_item_name(datetime_column, flavor) if datetime_column else None
    query = f"CREATE TABLE {table_name} ("
    if primary_key:
        col_db_type = cols_types[0][1]
        ### Add the auto-increment keyword when requested or when the key column
        ### was not supplied in `dtypes` (i.e. it is synthesized here).
        auto_increment_str = (' ' + AUTO_INCREMENT_COLUMN_FLAVORS.get(
            flavor,
            AUTO_INCREMENT_COLUMN_FLAVORS['default']
        )) if autoincrement or primary_key not in dtypes else ''
        col_name = sql_item_name(primary_key, flavor=flavor, schema=None)

        if flavor == 'sqlite':
            ### SQLite auto-increment keys must be INTEGER PRIMARY KEY.
            query += (
                f"\n    {col_name} "
                + (f"{col_db_type}" if not auto_increment_str else 'INTEGER')
                + f" PRIMARY KEY{auto_increment_str} NOT NULL,"
            )
        elif flavor == 'oracle':
            query += f"\n    {col_name} {col_db_type} {auto_increment_str} PRIMARY KEY,"
        elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key:
            ### TimescaleDB composite key is added after the column list (below).
            query += f"\n    {col_name} {col_db_type}{auto_increment_str} NOT NULL,"
        else:
            query += f"\n    {col_name} {col_db_type} PRIMARY KEY{auto_increment_str} NOT NULL,"

    for col, db_type in cols_types:
        if col == primary_key:
            continue
        col_name = sql_item_name(col, schema=None, flavor=flavor)
        query += f"\n    {col_name} {db_type},"
    if (
        flavor == 'timescaledb'
        and datetime_column
        and primary_key
        and datetime_column != primary_key
    ):
        ### TimescaleDB hypertables require the partitioning column in the key.
        query += f"\n    PRIMARY KEY({datetime_column_name}, {primary_key_name}),"
    ### Strip the trailing comma before closing the column list.
    query = query[:-1]
    query += "\n)"

    queries = [query]
    return queries
def _get_create_table_query_from_cte(
    query: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    autoincrement: bool = False,  # unused here; kept for signature parity with the dtypes builder
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Create a new table from a CTE query.

    Returns the `CREATE TABLE` / `SELECT INTO` query, followed by an
    `ALTER TABLE ... ADD PRIMARY KEY` query when `primary_key` is given.
    """
    import textwrap
    from meerschaum.utils.dtypes.sql import AUTO_INCREMENT_COLUMN_FLAVORS
    create_cte = 'create_query'
    create_cte_name = sql_item_name(create_cte, flavor, None)
    new_table_name = sql_item_name(new_table, flavor, schema)
    primary_key_constraint_name = (
        sql_item_name(f'pk_{new_table}', flavor, None)
        if primary_key
        else None
    )
    primary_key_name = (
        sql_item_name(primary_key, flavor, None)
        if primary_key
        else None
    )
    datetime_column_name = (
        sql_item_name(datetime_column, flavor)
        if datetime_column
        else None
    )
    if flavor in ('mssql',):
        query = query.lstrip()
        if 'with ' in query.lower():
            ### The source query already has CTEs: splice ours onto its WITH list,
            ### treating everything from the last SELECT as the final statement.
            final_select_ix = query.lower().rfind('select')
            create_table_query = (
                query[:final_select_ix].rstrip() + ',\n'
                + f"{create_cte_name} AS (\n"
                + query[final_select_ix:]
                + "\n)\n"
                + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}"
            )
        else:
            create_table_query = f"""
            SELECT *
            INTO {new_table_name}
            FROM ({query}) AS {create_cte_name}
            """

        ### MSSQL requires a named constraint for the primary key.
        alter_type_query = f"""
        ALTER TABLE {new_table_name}
        ADD CONSTRAINT {primary_key_constraint_name} PRIMARY KEY ({primary_key_name})
        """
    elif flavor in (None,):
        ### NOTE: unreachable placeholder branch (flavor is never None here).
        create_table_query = f"""
        WITH {create_cte_name} AS ({query})
        CREATE TABLE {new_table_name} AS
        SELECT *
        FROM {create_cte_name}
        """

        alter_type_query = f"""
        ALTER TABLE {new_table_name}
        ADD PRIMARY KEY ({primary_key_name})
        """
    elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'):
        ### Oracle does not accept an alias on the derived table.
        create_table_query = f"""
        CREATE TABLE {new_table_name} AS
        SELECT *
        FROM ({query})""" + (f""" AS {create_cte_name}""" if flavor != 'oracle' else '') + """
        """

        alter_type_query = f"""
        ALTER TABLE {new_table_name}
        ADD PRIMARY KEY ({primary_key_name})
        """
    elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key:
        ### TimescaleDB hypertables need the datetime column in the composite key.
        create_table_query = f"""
        SELECT *
        INTO {new_table_name}
        FROM ({query}) AS {create_cte_name}
        """

        alter_type_query = f"""
        ALTER TABLE {new_table_name}
        ADD PRIMARY KEY ({datetime_column_name}, {primary_key_name})
        """
    else:
        create_table_query = f"""
        SELECT *
        INTO {new_table_name}
        FROM ({query}) AS {create_cte_name}
        """

        alter_type_query = f"""
        ALTER TABLE {new_table_name}
        ADD PRIMARY KEY ({primary_key_name})
        """

    create_table_query = textwrap.dedent(create_table_query)
    if not primary_key:
        return [create_table_query]

    alter_type_query = textwrap.dedent(alter_type_query)

    return [
        create_table_query,
        alter_type_query,
    ]
def wrap_query_with_cte(
    sub_query: str,
    parent_query: str,
    flavor: str,
    cte_name: str = "src",
) -> str:
    """
    Wrap a subquery in a CTE and append an encapsulating query.

    Parameters
    ----------
    sub_query: str
        The query to be referenced. This may itself contain CTEs.
        Unless `cte_name` is provided, this will be aliased as `src`.

    parent_query: str
        The larger query to append which references the subquery.
        This must not contain CTEs.

    flavor: str
        The database flavor, e.g. `'mssql'`.

    cte_name: str, default 'src'
        The CTE alias, defaults to `src`.

    Returns
    -------
    An encapsulating query which allows you to treat `sub_query` as a temporary table.

    Examples
    --------

    ```python
    from meerschaum.utils.sql import wrap_query_with_cte
    sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
    parent_query = "SELECT newval * 3 FROM src"
    query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
    print(query)
    # WITH foo AS (SELECT 1 AS val),
    # [src] AS (
    #     SELECT (val * 2) AS newval FROM foo
    # )
    # SELECT newval * 3 FROM src
    ```

    """
    sub_query = sub_query.lstrip()
    cte_name_quoted = sql_item_name(cte_name, flavor, None)

    if flavor in NO_CTE_FLAVORS:
        ### No CTE support: substitute the subquery inline as a derived table,
        ### replacing both the quoted and unquoted alias in the parent query.
        return (
            parent_query
            .replace(cte_name_quoted, '--MRSM_SUBQUERY--')
            .replace(cte_name, '--MRSM_SUBQUERY--')
            .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}")
        )

    if 'with ' in sub_query.lower():
        ### The subquery already has a WITH list: append our CTE to it,
        ### treating everything from the last SELECT as the CTE body.
        final_select_ix = sub_query.lower().rfind('select')
        return (
            sub_query[:final_select_ix].rstrip() + ',\n'
            + f"{cte_name_quoted} AS (\n"
            + '    ' + sub_query[final_select_ix:]
            + "\n)\n"
            + parent_query
        )

    return (
        f"WITH {cte_name_quoted} AS (\n"
        f"    {sub_query}\n"
        f")\n{parent_query}"
    )


def format_cte_subquery(
    sub_query: str,
    flavor: str,
    sub_name: str = 'src',
    cols_to_select: Union[List[str], str] = '*',
) -> str:
    """
    Given a subquery, build a wrapper query that selects from the CTE subquery.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        If possible, give this name to the CTE (must be unquoted).

    cols_to_select: Union[List[str], str], default '*'
        If specified, choose which columns to select from the CTE.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the CTE.
    """
    quoted_sub_name = sql_item_name(sub_name, flavor, None)
    cols_str = (
        cols_to_select
        if isinstance(cols_to_select, str)
        else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select])
    )
    parent_query = (
        f"SELECT {cols_str}\n"
        f"FROM {quoted_sub_name}"
    )
    return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
def session_execute(
    session: 'sqlalchemy.orm.session.Session',
    queries: Union[List[str], str],
    with_results: bool = False,
    debug: bool = False,
) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session.

    with_results: bool, default False
        If `True`, return a list of result objects.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `SuccessTuple` indicating the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    if not isinstance(queries, list):
        queries = [queries]
    successes, msgs, results = [], [], []
    fail_msg = "Failed to execute queries."
    for query in queries:
        if debug:
            dprint(query)
        query_text = sqlalchemy.text(query)
        try:
            result = session.execute(query_text)
            query_success = result is not None
            query_msg = "Success" if query_success else fail_msg
        except Exception as e:
            query_success = False
            query_msg = f"{fail_msg}\n{e}"
            result = None
        successes.append(query_success)
        msgs.append(query_msg)
        results.append(result)
        if not query_success:
            ### Abort on the first failure; remaining queries are not attempted.
            session.rollback()
            break
    ### NOTE: an empty `queries` list yields `all([]) == True` (vacuous success).
    success, msg = all(successes), '\n'.join(msgs)
    if with_results:
        return (success, msg), results
    return success, msg
def get_reset_autoincrement_queries(
    table: str,
    column: str,
    connector: mrsm.connectors.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Return a list of queries to reset a table's auto-increment counter to the next largest value.

    Parameters
    ----------
    table: str
        The name of the table on which the auto-incrementing column exists.

    column: str
        The name of the auto-incrementing column.

    connector: mrsm.connectors.SQLConnector
        The SQLConnector to the database on which the table exists.

    schema: Optional[str], default None
        The schema of the table. Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of queries to be executed to reset the auto-incrementing column.
    """
    if not table_exists(table, connector, schema=schema, debug=debug):
        return []

    schema = schema or connector.schema
    max_id_name = sql_item_name('max_id', connector.flavor)
    table_name = sql_item_name(table, connector.flavor, schema)
    table_trunc = truncate_item_name(table, connector.flavor)
    table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema)
    column_name = sql_item_name(column, connector.flavor)
    if connector.flavor == 'oracle':
        ### Oracle identity columns use a system-generated sequence; look it up.
        df = connector.read(f"""
        SELECT SEQUENCE_NAME
        FROM ALL_TAB_IDENTITY_COLS
        WHERE TABLE_NAME IN '{table_trunc.upper()}'
        """, debug=debug)
        if len(df) > 0:
            table_seq_name = df['sequence_name'][0]

    max_id = connector.value(
        f"""
        SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name}
        FROM {table_name}
        """,
        debug=debug,
    )
    if max_id is None:
        return []

    reset_queries = reset_autoincrement_queries.get(
        connector.flavor,
        reset_autoincrement_queries['default']
    )
    if not isinstance(reset_queries, list):
        reset_queries = [reset_queries]

    return [
        query.format(
            column=column,
            column_name=column_name,
            table=table,
            table_name=table_name,
            table_seq_name=table_seq_name,
            val=max_id,
        )
        for query in reset_queries
    ]
def clean(substring: str) -> str:
    """
    Ensure a substring is clean enough to be inserted into a SQL query.
    Raises an exception when banned words are used.

    Parameters
    ----------
    substring: str
        The string to validate before interpolating it into a query.

    Returns
    -------
    The original `substring`, unchanged, if it contains no banned symbols.
    """
    ### NOTE: This is a best-effort substring check, not a substitute for
    ### parameterized queries.
    banned_symbols = [';', '--', 'drop ']
    lowered = str(substring).lower()
    for symbol in banned_symbols:
        if symbol in lowered:
            ### Import lazily so the happy path avoids the import entirely.
            from meerschaum.utils.warnings import error
            error(f"Invalid string: '{substring}'")
    return substring
Ensure a substring is clean enough to be inserted into a SQL query. Raises an exception when banned words are used.
def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now'
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime], default `'now'`
        Base datetime to which to add dateparts.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor = 'mssql',
    ...     begin = datetime(2022, 1, 1, 0, 0),
    ...     number = 1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))"
    >>> dateadd_str(
    ...     flavor = 'postgresql',
    ...     begin = datetime(2022, 1, 1, 0, 0),
    ...     number = 1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type
    dateutil_parser = attempt_import('dateutil.parser')

    ### Integers (including numpy ints) are passed through as epoch-style values.
    if 'int' in str(type(begin)).lower():
        return str(begin)
    if not begin:
        return ''

    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    begin_time = None
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    if begin_time is None:
        ### Unable to parse into a datetime.
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    else:
        ### `begin` is a valid datetime, so normalize and wrap it in quotes.
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    ### A '+' in an unparsed string presumably indicates a UTC offset.
    dt_is_utc = begin_time.tzinfo is not None if begin_time is not None else '+' in str(begin)
    db_type = get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        ### BUGFIX: this literal was missing its `f` prefix, emitting a raw
        ### "{db_type}" placeholder into the generated SQL.
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### Truncate the quoted literal's 6-digit microseconds to milliseconds
        ### (drop the last 3 digits and restore the closing quote).
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = f"CAST({begin} AS DATETIME(6))" if begin != 'now' else 'UTC_TIMESTAMP(6)'
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        ### NOTE(review): this 'now' format string ('%Y:%m:%d %M:%S.%f') omits
        ### the hour and uses colons in the date — looks inconsistent with
        ### `dt_format` below; confirm against upstream before changing.
        if begin == 'now':
            begin = str(
                datetime.now(timezone.utc).replace(tzinfo=None).strftime('%Y:%m:%d %M:%S.%f')
            )
        elif begin_time:
            begin = str(begin_time.strftime('%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
Generate a DATEADD
clause depending on database flavor.
Parameters
flavor (str, default
'postgresql'
): SQL database flavor, e.g.'postgresql'
,'sqlite'
.Currently supported flavors:
'postgresql'
'timescaledb'
'citus'
'cockroachdb'
'duckdb'
'mssql'
'mysql'
'mariadb'
'sqlite'
'oracle'
datepart (str, default
'day'
): Which part of the date to modify. Supported values:'year'
'month'
'day'
'hour'
'minute'
'second'
- number (Union[int, float], default
0
): How many units to add to the date part. - begin (Union[str, datetime], default
'now'
): Base datetime to which to add dateparts.
Returns
- The appropriate
DATEADD
string for the corresponding database flavor.
Examples
>>> dateadd_str(
... flavor = 'mssql',
... begin = datetime(2022, 1, 1, 0, 0),
... number = 1,
... )
"DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))"
>>> dateadd_str(
... flavor = 'postgresql',
... begin = datetime(2022, 1, 1, 0, 0),
... number = 1,
... )
"CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"
def test_connection(
    self,
    **kw: Any
) -> Union[bool, None]:
    """
    Test if a successful connection to the database may be made.

    Parameters
    ----------
    **kw:
        The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.

    Returns
    -------
    `True` if a connection is made, otherwise `False` or `None` in case of failure.

    """
    import warnings
    from meerschaum.connectors.poll import retry_connect

    ### Probe once with no wait; caller-supplied kwargs override these defaults.
    retry_kw = {
        'max_retries': 1,
        'retry_wait': 0,
        'warn': False,
        'connector': self,
    }
    retry_kw.update(kw)

    ### Silence the "Could not ..." warnings emitted while probing.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', 'Could not')
        try:
            return retry_connect(**retry_kw)
        except Exception:
            return False
Test if a successful connection to the database may be made.
Parameters
- **kw: The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.
Returns
`True` if a connection is made, otherwise `False` or `None` in case of failure.
def get_distinct_col_count(
    col: str,
    query: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    debug: bool = False
) -> Optional[int]:
    """
    Return the number of distinct values in a column of a SQL query.

    Parameters
    ----------
    col: str
        The column in the query to count.

    query: str
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None
        The SQLConnector to execute the query.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of distinct values in the column, or `None` if the query fails.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    quoted_col = sql_item_name(col, connector.flavor, None)

    ### NOTE: MySQL / MariaDB get a subquery form (presumably for pre-8.0
    ### CTE support); everything else wraps the query in a CTE.
    if connector.flavor in ('mysql', 'mariadb'):
        count_query = (
            f"""
            SELECT COUNT(*)
            FROM (
                SELECT DISTINCT {quoted_col}
                FROM ({query}) AS src
            ) AS dist"""
        )
    else:
        count_query = (
            f"""
            WITH src AS ( {query} ),
            dist AS ( SELECT DISTINCT {quoted_col} FROM src )
            SELECT COUNT(*) FROM dist"""
        )

    result = connector.value(count_query, debug=debug)
    try:
        return int(result)
    except Exception:
        return None
Returns the number of distinct items in a column of a SQL query.
Parameters
- col (str:): The column in the query to count.
- query (str:): The SQL query to count from.
- connector (Optional[mrsm.connectors.sql.SQLConnector], default None:): The SQLConnector to execute the query.
- debug (bool, default False:): Verbosity toggle.
Returns
- An `int` of the number of distinct values in the query's column, or `None` if the query fails.
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Parse SQL items depending on the flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).

    schema: Optional[str], default None
        Optionally prefix the item with a quoted schema.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    truncated = truncate_item_name(str(item), flavor)

    if flavor == 'oracle':
        truncated = pg_capital(truncated)
        ### System-reserved words must be quoted on Oracle;
        ### other identifiers get no wrappers at all.
        reserved = {
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table',
        }
        open_q, close_q = ('"', '"') if truncated.lower() in reserved else ('', '')
    else:
        open_q, close_q = table_wrappers.get(flavor, table_wrappers['default'])

    ### SQLite does not support schemas.
    if flavor == 'sqlite':
        schema = None

    schema_prefix = f"{open_q}{schema}{close_q}." if schema is not None else ''
    return f"{schema_prefix}{open_q}{truncated}{close_q}"
Parse SQL items depending on the flavor.
Parameters
- item (str :): The database item (table, view, etc.) in need of quotes.
- flavor (str :):
The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).
Returns
- A
str
which contains the inputitem
wrapped in the corresponding escape characters.
Examples
>>> sql_item_name('table', 'sqlite')
'"table"'
>>> sql_item_name('table', 'mssql')
"[table]"
>>> sql_item_name('table', 'postgresql', schema='abc')
'"abc"."table"'
def pg_capital(s: str) -> str:
    """
    If a string contains a capital letter (or any character requiring escaping),
    wrap it in double quotes.

    Parameters
    ----------
    s: str
        The string to be escaped.

    Returns
    -------
    The input string wrapped in quotes only if it needs them.

    Examples
    --------
    >>> pg_capital("My Table")
    '"My Table"'
    >>> pg_capital('my_table')
    'my_table'

    """
    ### Strings which already contain a double quote are returned untouched.
    if '"' in s:
        return s

    def _char_requires_quoting(c: str) -> bool:
        ### Anything outside lowercase a-z must be a digit or an underscore;
        ### capitals, spaces, and punctuation all force quoting.
        return (c < 'a' or c > 'z') and not c.isdigit() and c != '_'

    ### A leading underscore also forces quoting.
    if s.startswith('_') or any(_char_requires_quoting(c) for c in s):
        return f'"{s}"'
    return s
If string contains a capital letter, wrap it in double quotes.
Parameters
s (str :):
The string to be escaped.
Returns
- The input string wrapped in quotes only if it needs them.
Examples
>>> pg_capital("My Table")
'"My Table"'
>>> pg_capital('my_table')
'my_table'
def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.

    NOTE: Currently a no-op — Oracle identifier quoting is handled by
    `pg_capital()` within `sql_item_name()`. Presumably kept for
    backwards compatibility with existing callers.
    """
    return s
Capitalize the string of an item on an Oracle database.
def truncate_item_name(item: str, flavor: str) -> str:
    """
    Truncate item names to stay within the database flavor's character limit.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections

    ### Fall back to the default length limit for unknown flavors.
    max_len = max_name_lens.get(flavor, max_name_lens['default'])
    return truncate_string_sections(item, max_len=max_len)
Truncate item names to stay within the database flavor's character limit.
Parameters
- item (str): The database item being referenced. This string is the "canonical" name internally.
- flavor (str):
The flavor of the database on which
item
resides.
Returns
- The truncated string.
def build_where(
    params: Dict[str, Any],
    connector: Optional[meerschaum.connectors.sql.SQLConnector] = None,
    with_where: bool = True,
) -> str:
    """
    Build the `WHERE` clause based on the input criteria.

    Parameters
    ----------
    params: Dict[str, Any]
        The keywords dictionary to convert into a WHERE clause.
        If a value is a string which begins with an underscore, negate that value
        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
        A value of `_None` will be interpreted as `IS NOT NULL`.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The Meerschaum SQLConnector that will be executing the query.
        The connector is used to extract the SQL dialect.

    with_where: bool, default True
        If `True`, include the leading `'WHERE'` string.

    Returns
    -------
    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.

    Examples
    --------
    ```
    >>> print(build_where({'foo': [1, 2, 3]}))
    
    WHERE
        "foo" IN ('1', '2', '3')
    ```
    """
    import json
    from meerschaum.config.static import STATIC_CONFIG
    from meerschaum.utils.warnings import warn
    from meerschaum.utils.dtypes import value_is_null, none_if_null
    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
    ### Serialize the params for a cheap injection scan; fall back to `str()`
    ### for values JSON cannot encode.
    try:
        params_json = json.dumps(params)
    except Exception as e:
        params_json = str(params)
    bad_words = ['drop ', '--', ';']
    for word in bad_words:
        if word in params_json.lower():
            warn(f"Aborting build_where() due to possible SQL injection.")
            return ''

    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')
    where = ""
    leading_and = "\n    AND "
    for key, value in params.items():
        _key = sql_item_name(key, connector.flavor, None)
        ### search across a list (i.e. IN syntax)
        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
            ### Split the list into plain values ("includes") and
            ### negation-prefixed values ("excludes"); `None` (nulls) are
            ### handled separately because SQL `IN` cannot match NULL.
            includes = [
                none_if_null(item)
                for item in value
                if not str(item).startswith(negation_prefix)
            ]
            null_includes = [item for item in includes if item is None]
            not_null_includes = [item for item in includes if item is not None]
            excludes = [
                none_if_null(str(item)[len(negation_prefix):])
                for item in value
                if str(item).startswith(negation_prefix)
            ]
            null_excludes = [item for item in excludes if item is None]
            not_null_excludes = [item for item in excludes if item is not None]

            if includes:
                where += f"{leading_and}("
            if not_null_includes:
                where += f"{_key} IN ("
                for item in not_null_includes:
                    ### Escape single quotes by doubling them.
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                ### Trim the trailing ", " and close the IN list.
                where = where[:-2] + ")"
            if null_includes:
                where += ("\n    OR " if not_null_includes else "") + f"{_key} IS NULL"
            if includes:
                where += ")"

            if excludes:
                where += f"{leading_and}("
            if not_null_excludes:
                where += f"{_key} NOT IN ("
                for item in not_null_excludes:
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_excludes:
                where += ("\n    AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
            if excludes:
                where += ")"

            continue

        ### search a dictionary
        elif isinstance(value, dict):
            import json
            where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'")
            continue

        ### Scalar case: handle negation prefixes and null sentinels.
        eq_sign = '='
        is_null = 'IS NULL'
        if value_is_null(str(value).lstrip(negation_prefix)):
            ### Preserve the negation prefix on null sentinels so the
            ### `startswith` check below still fires.
            value = (
                (negation_prefix + 'None')
                if str(value).startswith(negation_prefix)
                else None
            )
        if str(value).startswith(negation_prefix):
            value = str(value)[len(negation_prefix):]
            eq_sign = '!='
            if value_is_null(value):
                value = None
                is_null = 'IS NOT NULL'
        quoted_value = str(value).replace("'", "''")
        where += (
            f"{leading_and}{_key} "
            + (is_null if value is None else f"{eq_sign} '{quoted_value}'")
        )

    ### Strip the first `leading_and` and optionally prepend "WHERE".
    if len(where) > 1:
        where = ("\nWHERE\n    " if with_where else '') + where[len(leading_and):]
    return where
Build the WHERE
clause based on the input criteria.
Parameters
- params (Dict[str, Any]:):
The keywords dictionary to convert into a WHERE clause.
If a value is a string which begins with an underscore, negate that value
(e.g.
!=
instead of=
orNOT IN
instead ofIN
). A value of_None
will be interpreted asIS NOT NULL
. - connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The Meerschaum SQLConnector that will be executing the query. The connector is used to extract the SQL dialect.
- with_where (bool, default True:):
If
True
, include the leading'WHERE'
string.
Returns
- A
str
of theWHERE
clause from the inputparams
dictionary for the connector's flavor.
Examples
>>> print(build_where({'foo': [1, 2, 3]}))
WHERE
"foo" IN ('1', '2', '3')
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """
    Check if a table exists.

    Parameters
    ----------
    table: str
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy')

    ### The name must be truncated to the flavor's limit before inspection.
    truncated_name = truncate_item_name(str(table), connector.flavor)
    inspector = sqlalchemy.inspect(connector.engine)
    return inspector.has_table(truncated_name, schema=(schema or connector.schema))
Check if a table exists.
Parameters
- table (str:): The name of the table in question.
- connector (mrsm.connectors.sql.SQLConnector): The connector to the database which holds the table.
- schema (Optional[str], default None):
Optionally specify the table schema.
Defaults to
connector.schema
. - debug (bool, default False :): Verbosity toggle.
Returns
- A
bool
indicating whether or not the table exists on the database.
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    ### DuckDB connections are not reflected into SQLAlchemy tables.
    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn

    if refresh:
        connector.metadata.clear()

    tables_cache = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy')
    cache_key = truncate_item_name(str(table), connector.flavor)

    reflect_kwargs = {'autoload_with': connector.engine}
    if schema:
        reflect_kwargs['schema'] = schema

    ### Reflect and cache the table if it's missing (or a refresh was requested).
    if refresh or cache_key not in tables_cache:
        try:
            tables_cache[cache_key] = sqlalchemy.Table(
                cache_key,
                connector.metadata,
                **reflect_kwargs
            )
        except sqlalchemy.exc.NoSuchTableError:
            warn(f"Table '{cache_key}' does not exist in '{connector}'.")
            return None
    return tables_cache[cache_key]
Construct a SQLAlchemy table from its name.
Parameters
- table (str): The name of the table on the database. Does not need to be escaped.
- connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The connector to the database which holds the table.
- schema (Optional[str], default None):
Specify on which schema the table resides.
Defaults to the schema set in
connector
. - refresh (bool, default False):
If
True
, rebuild the cached table object. - debug (bool, default False:): Verbosity toggle.
Returns
- A
sqlalchemy.Table
object for the table.
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    Returns
    -------
    A dictionary mapping column names to data types.
    """
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    ### Normalize the schema/database filters for flavors which lack them.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Precompute every casing/truncation variant the flavor queries may use.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )
    )

    ### Expected positional order of the result columns.
    cols = ['database', 'schema', 'table', 'column', 'type']
    result_cols_ix = dict(enumerate(cols))

    ### Only SQLConnectors accept a `debug` kwarg on `execute()`.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        ### DuckDB results come back as DataFrames; normalize rows to tuples.
        result_rows = (
            [
                row
                for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle upcases unquoted identifiers; lowercase simple all-caps
        ### names to restore the canonical internal form.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper()
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
Return a dictionary mapping a table's columns to data types. This is useful for inspecting tables created during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
an explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table()
instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor
. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str]. default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to data types.
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    Returns
    -------
    A dictionary mapping column names to a list of index documents
    (each with `'name'` and `'type'` keys).
    """
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    ### Normalize the schema/database filters for flavors which lack them.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Precompute every casing/truncation variant the flavor queries may use.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )
    )

    ### Expected positional order of the result columns.
    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    result_cols_ix = dict(enumerate(cols))

    ### Only SQLConnectors accept a `debug` kwarg on `execute()`.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        ### DuckDB results come back as DataFrames; normalize rows to tuples.
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Group index documents by column name; Oracle upcases unquoted
        ### identifiers, so lowercase simple all-caps names.
        cols_indices = defaultdict(lambda: [])
        for doc in cols_types_docs_filtered:
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            cols_indices[col].append(
                {
                    'name': doc.get('index', None),
                    'type': doc.get('index_type', None),
                }
            )

        return dict(cols_indices)
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
Return a dictionary mapping a table's columns to lists of indices. This is useful for inspecting tables created during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
an explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table()
instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, defaults to
connectable.flavor
. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str], default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to a list of indices.
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: List[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")
    ### SQLite < 3.33.0 lacks `UPDATE ... FROM`; fall back to a DELETE/INSERT template.
    ### NOTE(review): this is a lexicographic string comparison ('3.4.0' > '3.33.0'
    ### lexicographically) — confirm `db_version` is normalized upstream.
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and connectable.db_version < '3.33.0'
    ):
        flavor = 'sqlite_delete_insert'
    ### Upsert templates are stored under a separate `{flavor}-upsert` key.
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = update_queries.get(
        flavor_key,
        update_queries['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Partition the shared columns into join keys and value columns,
    ### applying any flavor-specific CAST type overrides.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    ### Without join keys there is nothing to match on; without value columns
    ### a plain update would be a no-op (an upsert may still insert).
    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    coalesce_join_cols_str = ', '.join(
        [
            'COALESCE('
            + sql_item_name(c_name, flavor)
            + ', '
            + get_null_replacement(c_type, flavor)
            + ')'
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        """Build the `SET` clause assigning `r_prefix` columns to `l_prefix` columns."""
        if not value_cols:
            return ''
        ### Cached type names may use underscores (e.g. 'DOUBLE_PRECISION');
        ### replace them with spaces for the CAST target.
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + ('CAST(' if flavor != 'sqlite' else '')
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + (' AS ' if flavor != 'sqlite' else '')
                + (c_type.replace('_', ' ') if flavor != 'sqlite' else '')
                + (')' if flavor != 'sqlite' else '')
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        """Build a NULL-safe equality predicate over the join columns."""
        return '\nAND\n'.join([
            (
                "COALESCE("
                + l_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
                + ' = '
                + "COALESCE("
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
            ) for c_name, c_type in join_cols_types
        ])

    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    ### Bound the join to the patch's datetime range so the target scan stays small.
    date_bounds_subquery = (
        f"""
        f.{dt_col_name} >= (SELECT MIN({dt_col_name}) FROM {patch_table_name})
        AND f.{dt_col_name} <= (SELECT MAX({dt_col_name}) FROM {patch_table_name})
        """
        if datetime_col
        else "1 = 1"
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "WHEN MATCHED THEN"
        f" UPDATE {sets_subquery('', 'p.')}"
    )

    ### MySQL/MariaDB `ON DUPLICATE KEY UPDATE` clause pieces.
    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    return [
        base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
        )
        for base_query in base_queries
    ]
Build a list of MERGE
, UPDATE
, DELETE
/INSERT
queries to apply a patch to target table.
Parameters
- target (str): The name of the target table.
- patch (str): The name of the patch table. This should have the same shape as the target.
- connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]):
The
SQLConnector
or SQLAlchemy session which will later execute the queries. - join_cols (List[str]): The columns to use to join the patch to the target.
- flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
- upsert (bool, default False):
If
True
, return an upsert query rather than an update. - datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
- schema (Optional[str], default None):
If provided, use this schema when quoting the target table.
Defaults to
connector.schema
. - patch_schema (Optional[str], default None):
If provided, use this schema when quoting the patch table.
Defaults to
schema
. - debug (bool, default False): Verbosity toggle.
Returns
- A list of query strings to perform the update operation.
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The database type for which to produce a NULL stand-in.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A SQL literal (or expression) which may stand in place of NULL for this type.
    `'-987654321'` (quoted) is the catch-all fallback.
    """
    ### Hoist the repeated lower-casing out of the branch checks.
    typ_lower = typ.lower()
    if 'int' in typ_lower or typ_lower in ('numeric', 'number'):
        return '-987654321'
    if 'bool' in typ_lower or typ_lower == 'bit':
        ### Imported lazily: only the boolean branch needs the cast-override table.
        from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        ### MySQL/MariaDB booleans are TINYINT, so the sentinel int is castable;
        ### elsewhere cast 0 (false).
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'
    if 'time' in typ_lower or 'date' in typ_lower:
        return dateadd_str(flavor=flavor, begin='1900-01-01')
    if 'float' in typ_lower or 'double' in typ_lower or typ_lower in ('decimal',):
        return '-987654321.0'
    if flavor == 'oracle' and typ_lower.split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"
    if typ_lower in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"
    ### Oracle string literals get the national-character prefix.
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"
Return a value that may temporarily be used in place of NULL for this type.
Parameters
- typ (str): The typ to be converted to NULL.
- flavor (str): The database flavor for which this value will be used.
Returns
- A value which may stand in place of NULL for this type.
'None'
is returned if a value cannot be determined.
def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.

    Parameters
    ----------
    conn: SQLConnector
        The connector whose database version to fetch.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    The version string reported by the database, or `None` if unavailable.
    """
    flavor = conn.flavor
    query_template = version_queries.get(flavor, version_queries['default'])
    version_query = query_template.format(
        version_name=sql_item_name('version', flavor, None),
    )
    return conn.value(version_query, debug=debug)
Fetch the database version if possible.
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    old_table_name = sql_item_name(old_table, flavor, schema)
    new_table_name = sql_item_name(new_table, flavor, None)
    tmp_table = '_tmp_rename_' + new_table
    tmp_table_name = sql_item_name(tmp_table, flavor, schema)
    if flavor == 'mssql':
        ### `sp_rename` takes the raw (unquoted) identifiers.
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""
    if flavor == 'duckdb':
        ### Copy old -> tmp -> new, then drop tmp and old.
        ### NOTE(review): presumably this double-copy works around a DuckDB
        ### rename limitation — confirm before simplifying to a single copy.
        return (
            get_create_table_queries(
                f"SELECT * FROM {old_table_name}",
                tmp_table,
                'duckdb',
                schema,
            ) + get_create_table_queries(
                f"SELECT * FROM {tmp_table_name}",
                new_table,
                'duckdb',
                schema,
            ) + [
                f"DROP TABLE {if_exists_str} {tmp_table_name}",
                f"DROP TABLE {if_exists_str} {old_table_name}",
            ]
        )

    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]
Return queries to alter a table's name.
Parameters
- old_table (str): The unquoted name of the old table.
- new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table resides.
Returns
- A list of
ALTER TABLE
or equivalent queries for the database flavor.
def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    ### Legacy single-query interface: delegate and keep only the first query.
    queries = get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )
    return queries[0]
NOTE: This function is deprecated. Use get_create_table_queries()
instead.
Return a query to create a new table from a SELECT
query.
Parameters
- query (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLE
query from the givendtypes
columns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table will reside.
Returns
- A
CREATE TABLE
(orSELECT INTO
) query for the database flavor.
def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    if not isinstance(query_or_dtypes, (str, dict)):
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    ### Dispatch on the input kind: a query string builds via a CTE,
    ### a dtypes dictionary builds the column definitions directly.
    if isinstance(query_or_dtypes, str):
        method = _get_create_table_query_from_cte
    else:
        method = _get_create_table_query_from_dtypes

    ### Some flavors (e.g. Citus, DuckDB) do not support auto-incrementing columns.
    allow_autoincrement = autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS

    return method(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        autoincrement=allow_autoincrement,
        datetime_column=datetime_column,
    )
Return a query to create a new table from a SELECT
query or a dtypes
dictionary.
Parameters
- query_or_dtypes (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLE
query from the givendtypes
columns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table will reside.
- primary_key (Optional[str], default None): If provided, designate this column as the primary key in the new table.
- autoincrement (bool, default False):
If
True
andprimary_key
is provided, create theprimary_key
column as an auto-incrementing integer column. - datetime_column (Optional[str], default None): If provided, include this column in the primary key. Applicable to TimescaleDB only.
Returns
- A
CREATE TABLE
(orSELECT INTO
) query for the database flavor.
def wrap_query_with_cte(
    sub_query: str,
    parent_query: str,
    flavor: str,
    cte_name: str = "src",
) -> str:
    """
    Wrap a subquery in a CTE and append an encapsulating query.

    Parameters
    ----------
    sub_query: str
        The query to be referenced. This may itself contain CTEs.
        Unless `cte_name` is provided, this will be aliased as `src`.

    parent_query: str
        The larger query to append which references the subquery.
        This must not contain CTEs.

    flavor: str
        The database flavor, e.g. `'mssql'`.

    cte_name: str, default 'src'
        The CTE alias, defaults to `src`.

    Returns
    -------
    An encapsulating query which allows you to treat `sub_query` as a temporary table.

    Examples
    --------

    ```python
    from meerschaum.utils.sql import wrap_query_with_cte
    sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
    parent_query = "SELECT newval * 3 FROM src"
    query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
    print(query)
    # WITH foo AS (SELECT 1 AS val),
    # [src] AS (
    #     SELECT (val * 2) AS newval FROM foo
    # )
    # SELECT newval * 3 FROM src
    ```

    """
    sub_query = sub_query.lstrip()
    cte_name_quoted = sql_item_name(cte_name, flavor, None)

    if flavor in NO_CTE_FLAVORS:
        ### No CTE support: splice the subquery in as an inline derived table.
        ### Replace both the quoted and unquoted alias via a placeholder so the
        ### second `.replace()` cannot clobber the result of the first.
        return (
            parent_query
            .replace(cte_name_quoted, '--MRSM_SUBQUERY--')
            .replace(cte_name, '--MRSM_SUBQUERY--')
            .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}")
        )

    if 'with ' in sub_query.lower():
        ### The subquery already starts a WITH chain: append this CTE to it by
        ### splitting the subquery at its final SELECT.
        ### NOTE(review): the `'with '` / `rfind('select')` substring checks are
        ### heuristics and could misfire on queries containing those words in
        ### literals or identifiers — confirm callers only pass well-formed CTEs.
        final_select_ix = sub_query.lower().rfind('select')
        return (
            sub_query[:final_select_ix].rstrip() + ',\n'
            + f"{cte_name_quoted} AS (\n"
            + '    ' + sub_query[final_select_ix:]
            + "\n)\n"
            + parent_query
        )

    return (
        f"WITH {cte_name_quoted} AS (\n"
        f"    {sub_query}\n"
        f")\n{parent_query}"
    )
Wrap a subquery in a CTE and append an encapsulating query.
Parameters
- sub_query (str):
The query to be referenced. This may itself contain CTEs.
Unless
cte_name
is provided, this will be aliased assrc
. - parent_query (str): The larger query to append which references the subquery. This must not contain CTEs.
- flavor (str):
The database flavor, e.g.
'mssql'
. - cte_name (str, default 'src'):
The CTE alias, defaults to
src
.
Returns
- An encapsulating query which allows you to treat
sub_query
as a temporary table.
Examples
from meerschaum.utils.sql import wrap_query_with_cte
sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
parent_query = "SELECT newval * 3 FROM src"
query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
print(query)
# WITH foo AS (SELECT 1 AS val),
# [src] AS (
# SELECT (val * 2) AS newval FROM foo
# )
# SELECT newval * 3 FROM src
def format_cte_subquery(
    sub_query: str,
    flavor: str,
    sub_name: str = 'src',
    cols_to_select: Union[List[str], str] = '*',
) -> str:
    """
    Given a subquery, build a wrapper query that selects from the CTE subquery.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        If possible, give this name to the CTE (must be unquoted).

    cols_to_select: Union[List[str], str], default '*'
        If specified, choose which columns to select from the CTE.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the CTE.
    """
    quoted_sub_name = sql_item_name(sub_name, flavor, None)
    ### A string passes through as-is; a list is quoted per-item and comma-joined.
    cols_str = (
        cols_to_select
        if isinstance(cols_to_select, str)
        else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select])
    )
    parent_query = (
        f"SELECT {cols_str}\n"
        f"FROM {quoted_sub_name}"
    )
    return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
Given a subquery, build a wrapper query that selects from the CTE subquery.
Parameters
- sub_query (str): The subquery to wrap.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - sub_name (str, default 'src'): If possible, give this name to the CTE (must be unquoted).
- cols_to_select (Union[List[str], str], default '*'): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.
Returns
- A wrapper query that selects from the CTE.
def session_execute(
    session: 'sqlalchemy.orm.session.Session',
    queries: Union[List[str], str],
    with_results: bool = False,
    debug: bool = False,
) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session.

    with_results: bool, default False
        If `True`, return a list of result objects.

    debug: bool, default False
        Verbosity toggle (currently unused; accepted for interface parity).

    Returns
    -------
    A `SuccessTuple` indicating the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    if not isinstance(queries, list):
        queries = [queries]
    successes, msgs, results = [], [], []
    ### Loop-invariant; hoisted out of the per-query loop.
    fail_msg = "Failed to execute queries."
    for query in queries:
        query_text = sqlalchemy.text(query)
        try:
            result = session.execute(query_text)
            query_success = result is not None
            query_msg = "Success" if query_success else fail_msg
        except Exception as e:
            query_success = False
            query_msg = f"{fail_msg}\n{e}"
            result = None
        successes.append(query_success)
        msgs.append(query_msg)
        results.append(result)
        if not query_success:
            ### Roll back the whole transaction on the first failure and stop.
            session.rollback()
            break
    success, msg = all(successes), '\n'.join(msgs)
    if with_results:
        return (success, msg), results
    return success, msg
Similar to SQLConnector.exec_queries()
, execute a list of queries
and roll back when one fails.
Parameters
- session (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
- queries (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
- with_results (bool, default False):
If
True, return a list of result objects. - debug (bool, default False): Verbosity toggle.
, return a list of result objects.
Returns
- A
SuccessTuple
indicating the queries were successfully executed. - If
with_results
, return theSuccessTuple
and a list of results.
def get_reset_autoincrement_queries(
    table: str,
    column: str,
    connector: mrsm.connectors.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Return a list of queries to reset a table's auto-increment counter to the next largest value.

    Parameters
    ----------
    table: str
        The name of the table on which the auto-incrementing column exists.

    column: str
        The name of the auto-incrementing column.

    connector: mrsm.connectors.SQLConnector
        The SQLConnector to the database on which the table exists.

    schema: Optional[str], default None
        The schema of the table. Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of queries to be executed to reset the auto-incrementing column.
    """
    if not table_exists(table, connector, schema=schema, debug=debug):
        return []

    schema = schema or connector.schema
    max_id_name = sql_item_name('max_id', connector.flavor)
    table_name = sql_item_name(table, connector.flavor, schema)
    table_trunc = truncate_item_name(table, connector.flavor)
    ### Default to the PostgreSQL-style sequence name (`{table}_{column}_seq`).
    table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema)
    column_name = sql_item_name(column, connector.flavor)
    if connector.flavor == 'oracle':
        ### Oracle generates its own identity sequence names; look up the real one.
        ### NOTE(review): presumably `connector.read()` lower-cases result columns
        ### (hence `df['sequence_name']`) — confirm against the connector.
        df = connector.read(f"""
        SELECT SEQUENCE_NAME
        FROM ALL_TAB_IDENTITY_COLS
        WHERE TABLE_NAME IN '{table_trunc.upper()}'
        """, debug=debug)
        if len(df) > 0:
            table_seq_name = df['sequence_name'][0]

    max_id = connector.value(
        f"""
        SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name}
        FROM {table_name}
        """,
        debug=debug,
    )
    if max_id is None:
        return []

    reset_queries = reset_autoincrement_queries.get(
        connector.flavor,
        reset_autoincrement_queries['default']
    )
    if not isinstance(reset_queries, list):
        reset_queries = [reset_queries]

    ### NOTE(review): `val` is the current MAX, not MAX + 1 — the templates in
    ### `reset_autoincrement_queries` presumably add the offset; confirm.
    return [
        query.format(
            column=column,
            column_name=column_name,
            table=table,
            table_name=table_name,
            table_seq_name=table_seq_name,
            val=(max_id),
        )
        for query in reset_queries
    ]
Return a list of queries to reset a table's auto-increment counter to the next largest value.
Parameters
- table (str): The name of the table on which the auto-incrementing column exists.
- column (str): The name of the auto-incrementing column.
- connector (mrsm.connectors.SQLConnector): The SQLConnector to the database on which the table exists.
- schema (Optional[str], default None):
The schema of the table. Defaults to
connector.schema
.
Returns
- A list of queries to be executed to reset the auto-incrementing column.