meerschaum.utils.sql
Flavor-specific SQL tools.
1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# vim:fenc=utf-8 4 5""" 6Flavor-specific SQL tools. 7""" 8 9from __future__ import annotations 10from datetime import datetime, timezone, timedelta 11import meerschaum as mrsm 12from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple 13### Preserve legacy imports. 14from meerschaum.utils.dtypes.sql import ( 15 DB_TO_PD_DTYPES, 16 PD_TO_DB_DTYPES_FLAVORS, 17 get_pd_type_from_db_type as get_pd_type, 18 get_db_type_from_pd_type as get_db_type, 19) 20from meerschaum.utils.warnings import warn 21from meerschaum.utils.debug import dprint 22 23test_queries = { 24 'default' : 'SELECT 1', 25 'oracle' : 'SELECT 1 FROM DUAL', 26 'informix' : 'SELECT COUNT(*) FROM systables', 27 'hsqldb' : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS', 28} 29### `table_name` is the escaped name of the table. 30### `table` is the unescaped name of the table. 31exists_queries = { 32 'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0", 33} 34version_queries = { 35 'default': "SELECT VERSION() AS {version_name}", 36 'sqlite': "SELECT SQLITE_VERSION() AS {version_name}", 37 'mssql': "SELECT @@version", 38 'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1", 39} 40SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'} 41COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'postgresql', 'citus'} 42update_queries = { 43 'default': """ 44 UPDATE {target_table_name} AS f 45 {sets_subquery_none} 46 FROM {target_table_name} AS t 47 INNER JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 48 ON {and_subquery_t} 49 WHERE 50 {and_subquery_f} 51 AND {date_bounds_subquery} 52 """, 53 'timescaledb-upsert': """ 54 INSERT INTO {target_table_name} ({patch_cols_str}) 55 SELECT {patch_cols_str} 56 FROM {patch_table_name} 57 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 58 """, 59 'postgresql-upsert': """ 60 INSERT INTO {target_table_name} ({patch_cols_str}) 61 SELECT 
{patch_cols_str} 62 FROM {patch_table_name} 63 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 64 """, 65 'citus-upsert': """ 66 INSERT INTO {target_table_name} ({patch_cols_str}) 67 SELECT {patch_cols_str} 68 FROM {patch_table_name} 69 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 70 """, 71 'cockroachdb-upsert': """ 72 INSERT INTO {target_table_name} ({patch_cols_str}) 73 SELECT {patch_cols_str} 74 FROM {patch_table_name} 75 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 76 """, 77 'mysql': """ 78 UPDATE {target_table_name} AS f 79 JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 80 ON {and_subquery_f} 81 {sets_subquery_f} 82 WHERE {date_bounds_subquery} 83 """, 84 'mysql-upsert': """ 85 REPLACE INTO {target_table_name} ({patch_cols_str}) 86 SELECT {patch_cols_str} 87 FROM {patch_table_name} 88 """, 89 'mariadb': """ 90 UPDATE {target_table_name} AS f 91 JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p 92 ON {and_subquery_f} 93 {sets_subquery_f} 94 WHERE {date_bounds_subquery} 95 """, 96 'mariadb-upsert': """ 97 REPLACE INTO {target_table_name} ({patch_cols_str}) 98 SELECT {patch_cols_str} 99 FROM {patch_table_name} 100 """, 101 'mssql': """ 102 MERGE {target_table_name} f 103 USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p 104 ON {and_subquery_f} 105 AND {date_bounds_subquery} 106 WHEN MATCHED THEN 107 UPDATE 108 {sets_subquery_none}; 109 """, 110 'oracle': """ 111 MERGE INTO {target_table_name} f 112 USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p 113 ON ( 114 {and_subquery_f} 115 AND {date_bounds_subquery} 116 ) 117 WHEN MATCHED THEN 118 UPDATE 119 {sets_subquery_none} 120 """, 121 'sqlite-upsert': """ 122 INSERT INTO {target_table_name} ({patch_cols_str}) 123 SELECT {patch_cols_str} 124 FROM {patch_table_name} 125 WHERE true 126 ON CONFLICT ({join_cols_str}) DO 
{update_or_nothing} {sets_subquery_none_excluded} 127 """, 128 'sqlite_delete_insert': [ 129 """ 130 DELETE FROM {target_table_name} AS f 131 WHERE ROWID IN ( 132 SELECT t.ROWID 133 FROM {target_table_name} AS t 134 INNER JOIN (SELECT DISTINCT * FROM {patch_table_name}) AS p 135 ON {and_subquery_t} 136 ); 137 """, 138 """ 139 INSERT INTO {target_table_name} AS f 140 SELECT DISTINCT {patch_cols_str} FROM {patch_table_name} AS p 141 """, 142 ], 143} 144columns_types_queries = { 145 'default': """ 146 SELECT 147 table_catalog AS database, 148 table_schema AS schema, 149 table_name AS table, 150 column_name AS column, 151 data_type AS type 152 FROM information_schema.columns 153 WHERE table_name IN ('{table}', '{table_trunc}') 154 """, 155 'sqlite': """ 156 SELECT 157 '' "database", 158 '' "schema", 159 m.name "table", 160 p.name "column", 161 p.type "type" 162 FROM sqlite_master m 163 LEFT OUTER JOIN pragma_table_info((m.name)) p 164 ON m.name <> p.name 165 WHERE m.type = 'table' 166 AND m.name IN ('{table}', '{table_trunc}') 167 """, 168 'mssql': """ 169 SELECT 170 TABLE_CATALOG AS [database], 171 TABLE_SCHEMA AS [schema], 172 TABLE_NAME AS [table], 173 COLUMN_NAME AS [column], 174 DATA_TYPE AS [type] 175 FROM INFORMATION_SCHEMA.COLUMNS 176 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 177 """, 178 'mysql': """ 179 SELECT 180 TABLE_SCHEMA `database`, 181 TABLE_SCHEMA `schema`, 182 TABLE_NAME `table`, 183 COLUMN_NAME `column`, 184 DATA_TYPE `type` 185 FROM INFORMATION_SCHEMA.COLUMNS 186 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 187 """, 188 'mariadb': """ 189 SELECT 190 TABLE_SCHEMA `database`, 191 TABLE_SCHEMA `schema`, 192 TABLE_NAME `table`, 193 COLUMN_NAME `column`, 194 DATA_TYPE `type` 195 FROM INFORMATION_SCHEMA.COLUMNS 196 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 197 """, 198 'oracle': """ 199 SELECT 200 NULL AS "database", 201 NULL AS "schema", 202 TABLE_NAME AS "table", 203 COLUMN_NAME AS "column", 204 DATA_TYPE AS "type" 205 FROM 
all_tab_columns 206 WHERE TABLE_NAME IN ( 207 '{table}', 208 '{table_trunc}', 209 '{table_lower}', 210 '{table_lower_trunc}', 211 '{table_upper}', 212 '{table_upper_trunc}' 213 ) 214 """, 215} 216hypertable_queries = { 217 'timescaledb': 'SELECT hypertable_size(\'{table_name}\')', 218 'citus': 'SELECT citus_table_size(\'{table_name}\')', 219} 220table_wrappers = { 221 'default' : ('"', '"'), 222 'timescaledb': ('"', '"'), 223 'citus' : ('"', '"'), 224 'duckdb' : ('"', '"'), 225 'postgresql' : ('"', '"'), 226 'sqlite' : ('"', '"'), 227 'mysql' : ('`', '`'), 228 'mariadb' : ('`', '`'), 229 'mssql' : ('[', ']'), 230 'cockroachdb': ('"', '"'), 231 'oracle' : ('"', '"'), 232} 233max_name_lens = { 234 'default' : 64, 235 'mssql' : 128, 236 'oracle' : 30, 237 'postgresql' : 64, 238 'timescaledb': 64, 239 'citus' : 64, 240 'cockroachdb': 64, 241 'sqlite' : 1024, ### Probably more, but 1024 seems more than reasonable. 242 'mysql' : 64, 243 'mariadb' : 64, 244} 245json_flavors = {'postgresql', 'timescaledb', 'citus', 'cockroachdb'} 246NO_SCHEMA_FLAVORS = {'oracle', 'sqlite', 'mysql', 'mariadb', 'duckdb'} 247DEFAULT_SCHEMA_FLAVORS = { 248 'postgresql': 'public', 249 'timescaledb': 'public', 250 'citus': 'public', 251 'cockroachdb': 'public', 252 'mysql': 'mysql', 253 'mariadb': 'mysql', 254 'mssql': 'dbo', 255} 256OMIT_NULLSFIRST_FLAVORS = {'mariadb', 'mysql', 'mssql'} 257 258SINGLE_ALTER_TABLE_FLAVORS = {'duckdb', 'sqlite', 'mssql', 'oracle'} 259NO_CTE_FLAVORS = {'mysql', 'mariadb'} 260NO_SELECT_INTO_FLAVORS = {'sqlite', 'oracle', 'mysql', 'mariadb', 'duckdb'} 261 262 263def clean(substring: str) -> str: 264 """ 265 Ensure a substring is clean enough to be inserted into a SQL query. 266 Raises an exception when banned words are used. 
267 """ 268 from meerschaum.utils.warnings import error 269 banned_symbols = [';', '--', 'drop ',] 270 for symbol in banned_symbols: 271 if symbol in str(substring).lower(): 272 error(f"Invalid string: '{substring}'") 273 274 275def dateadd_str( 276 flavor: str = 'postgresql', 277 datepart: str = 'day', 278 number: Union[int, float] = 0, 279 begin: Union[str, datetime, int] = 'now' 280 ) -> str: 281 """ 282 Generate a `DATEADD` clause depending on database flavor. 283 284 Parameters 285 ---------- 286 flavor: str, default `'postgresql'` 287 SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. 288 289 Currently supported flavors: 290 291 - `'postgresql'` 292 - `'timescaledb'` 293 - `'citus'` 294 - `'cockroachdb'` 295 - `'duckdb'` 296 - `'mssql'` 297 - `'mysql'` 298 - `'mariadb'` 299 - `'sqlite'` 300 - `'oracle'` 301 302 datepart: str, default `'day'` 303 Which part of the date to modify. Supported values: 304 305 - `'year'` 306 - `'month'` 307 - `'day'` 308 - `'hour'` 309 - `'minute'` 310 - `'second'` 311 312 number: Union[int, float], default `0` 313 How many units to add to the date part. 314 315 begin: Union[str, datetime], default `'now'` 316 Base datetime to which to add dateparts. 317 318 Returns 319 ------- 320 The appropriate `DATEADD` string for the corresponding database flavor. 321 322 Examples 323 -------- 324 >>> dateadd_str( 325 ... flavor = 'mssql', 326 ... begin = datetime(2022, 1, 1, 0, 0), 327 ... number = 1, 328 ... ) 329 "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))" 330 >>> dateadd_str( 331 ... flavor = 'postgresql', 332 ... begin = datetime(2022, 1, 1, 0, 0), 333 ... number = 1, 334 ... 
) 335 "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'" 336 337 """ 338 from meerschaum.utils.debug import dprint 339 from meerschaum.utils.packages import attempt_import 340 from meerschaum.utils.warnings import error 341 dateutil_parser = attempt_import('dateutil.parser') 342 if 'int' in str(type(begin)).lower(): 343 return str(begin) 344 if not begin: 345 return '' 346 347 _original_begin = begin 348 begin_time = None 349 ### Sanity check: make sure `begin` is a valid datetime before we inject anything. 350 if not isinstance(begin, datetime): 351 try: 352 begin_time = dateutil_parser.parse(begin) 353 except Exception: 354 begin_time = None 355 else: 356 begin_time = begin 357 358 ### Unable to parse into a datetime. 359 if begin_time is None: 360 ### Throw an error if banned symbols are included in the `begin` string. 361 clean(str(begin)) 362 ### If begin is a valid datetime, wrap it in quotes. 363 else: 364 if isinstance(begin, datetime) and begin.tzinfo is not None: 365 begin = begin.astimezone(timezone.utc) 366 begin = ( 367 f"'{begin.replace(tzinfo=None)}'" 368 if isinstance(begin, datetime) 369 else f"'{begin}'" 370 ) 371 372 da = "" 373 if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'): 374 begin = ( 375 f"CAST({begin} AS TIMESTAMP)" if begin != 'now' 376 else "CAST(NOW() AT TIME ZONE 'utc' AS TIMESTAMP)" 377 ) 378 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 379 380 elif flavor == 'duckdb': 381 begin = f"CAST({begin} AS TIMESTAMP)" if begin != 'now' else 'NOW()' 382 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 383 384 elif flavor in ('mssql',): 385 if begin_time and begin_time.microsecond != 0: 386 begin = begin[:-4] + "'" 387 begin = f"CAST({begin} AS DATETIME)" if begin != 'now' else 'GETUTCDATE()' 388 da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin 389 390 elif flavor in ('mysql', 'mariadb'): 391 begin = f"CAST({begin} AS DATETIME(6))" if 
begin != 'now' else 'UTC_TIMESTAMP(6)' 392 da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin) 393 394 elif flavor == 'sqlite': 395 da = f"datetime({begin}, '{number} {datepart}')" 396 397 elif flavor == 'oracle': 398 if begin == 'now': 399 begin = str( 400 datetime.now(timezone.utc).replace(tzinfo=None).strftime('%Y:%m:%d %M:%S.%f') 401 ) 402 elif begin_time: 403 begin = str(begin_time.strftime('%Y-%m-%d %H:%M:%S.%f')) 404 dt_format = 'YYYY-MM-DD HH24:MI:SS.FF' 405 _begin = f"'{begin}'" if begin_time else begin 406 da = ( 407 (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin) 408 + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "") 409 ) 410 return da 411 412 413def test_connection( 414 self, 415 **kw: Any 416 ) -> Union[bool, None]: 417 """ 418 Test if a successful connection to the database may be made. 419 420 Parameters 421 ---------- 422 **kw: 423 The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`. 424 425 Returns 426 ------- 427 `True` if a connection is made, otherwise `False` or `None` in case of failure. 428 429 """ 430 import warnings 431 from meerschaum.connectors.poll import retry_connect 432 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 433 _default_kw.update(kw) 434 with warnings.catch_warnings(): 435 warnings.filterwarnings('ignore', 'Could not') 436 try: 437 return retry_connect(**_default_kw) 438 except Exception as e: 439 return False 440 441 442def get_distinct_col_count( 443 col: str, 444 query: str, 445 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 446 debug: bool = False 447 ) -> Optional[int]: 448 """ 449 Returns the number of distinct items in a column of a SQL query. 450 451 Parameters 452 ---------- 453 col: str: 454 The column in the query to count. 455 456 query: str: 457 The SQL query to count from. 
458 459 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 460 The SQLConnector to execute the query. 461 462 debug: bool, default False: 463 Verbosity toggle. 464 465 Returns 466 ------- 467 An `int` of the number of columns in the query or `None` if the query fails. 468 469 """ 470 if connector is None: 471 connector = mrsm.get_connector('sql') 472 473 _col_name = sql_item_name(col, connector.flavor, None) 474 475 _meta_query = ( 476 f""" 477 WITH src AS ( {query} ), 478 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 479 SELECT COUNT(*) FROM dist""" 480 ) if connector.flavor not in ('mysql', 'mariadb') else ( 481 f""" 482 SELECT COUNT(*) 483 FROM ( 484 SELECT DISTINCT {_col_name} 485 FROM ({query}) AS src 486 ) AS dist""" 487 ) 488 489 result = connector.value(_meta_query, debug=debug) 490 try: 491 return int(result) 492 except Exception as e: 493 return None 494 495 496def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 497 """ 498 Parse SQL items depending on the flavor. 499 500 Parameters 501 ---------- 502 item: str : 503 The database item (table, view, etc.) in need of quotes. 504 505 flavor: str : 506 The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.). 507 508 Returns 509 ------- 510 A `str` which contains the input `item` wrapped in the corresponding escape characters. 511 512 Examples 513 -------- 514 >>> sql_item_name('table', 'sqlite') 515 '"table"' 516 >>> sql_item_name('table', 'mssql') 517 "[table]" 518 >>> sql_item_name('table', 'postgresql', schema='abc') 519 '"abc"."table"' 520 521 """ 522 truncated_item = truncate_item_name(str(item), flavor) 523 if flavor == 'oracle': 524 truncated_item = pg_capital(truncated_item) 525 ### NOTE: System-reserved words must be quoted. 
526 if truncated_item.lower() in ( 527 'float', 'varchar', 'nvarchar', 'clob', 528 'boolean', 'integer', 'table', 529 ): 530 wrappers = ('"', '"') 531 else: 532 wrappers = ('', '') 533 else: 534 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 535 536 ### NOTE: SQLite does not support schemas. 537 if flavor == 'sqlite': 538 schema = None 539 540 schema_prefix = ( 541 (wrappers[0] + schema + wrappers[1] + '.') 542 if schema is not None 543 else '' 544 ) 545 546 return schema_prefix + wrappers[0] + truncated_item + wrappers[1] 547 548 549def pg_capital(s: str) -> str: 550 """ 551 If string contains a capital letter, wrap it in double quotes. 552 553 Parameters 554 ---------- 555 s: str : 556 The string to be escaped. 557 558 Returns 559 ------- 560 The input string wrapped in quotes only if it needs them. 561 562 Examples 563 -------- 564 >>> pg_capital("My Table") 565 '"My Table"' 566 >>> pg_capital('my_table') 567 'my_table' 568 569 """ 570 if '"' in s: 571 return s 572 needs_quotes = s.startswith('_') 573 for c in str(s): 574 if ord(c) < ord('a') or ord(c) > ord('z'): 575 if not c.isdigit() and c != '_': 576 needs_quotes = True 577 break 578 if needs_quotes: 579 return '"' + s + '"' 580 return s 581 582 583def oracle_capital(s: str) -> str: 584 """ 585 Capitalize the string of an item on an Oracle database. 586 """ 587 return s 588 589 590def truncate_item_name(item: str, flavor: str) -> str: 591 """ 592 Truncate item names to stay within the database flavor's character limit. 593 594 Parameters 595 ---------- 596 item: str 597 The database item being referenced. This string is the "canonical" name internally. 598 599 flavor: str 600 The flavor of the database on which `item` resides. 601 602 Returns 603 ------- 604 The truncated string. 
605 """ 606 from meerschaum.utils.misc import truncate_string_sections 607 return truncate_string_sections( 608 item, max_len=max_name_lens.get(flavor, max_name_lens['default']) 609 ) 610 611 612def build_where( 613 params: Dict[str, Any], 614 connector: Optional[meerschaum.connectors.sql.SQLConnector] = None, 615 with_where: bool = True, 616 ) -> str: 617 """ 618 Build the `WHERE` clause based on the input criteria. 619 620 Parameters 621 ---------- 622 params: Dict[str, Any]: 623 The keywords dictionary to convert into a WHERE clause. 624 If a value is a string which begins with an underscore, negate that value 625 (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). 626 A value of `_None` will be interpreted as `IS NOT NULL`. 627 628 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 629 The Meerschaum SQLConnector that will be executing the query. 630 The connector is used to extract the SQL dialect. 631 632 with_where: bool, default True: 633 If `True`, include the leading `'WHERE'` string. 634 635 Returns 636 ------- 637 A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor. 
638 639 Examples 640 -------- 641 ``` 642 >>> print(build_where({'foo': [1, 2, 3]})) 643 644 WHERE 645 "foo" IN ('1', '2', '3') 646 ``` 647 """ 648 import json 649 from meerschaum.config.static import STATIC_CONFIG 650 from meerschaum.utils.warnings import warn 651 from meerschaum.utils.dtypes import value_is_null, none_if_null 652 negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix'] 653 try: 654 params_json = json.dumps(params) 655 except Exception as e: 656 params_json = str(params) 657 bad_words = ['drop ', '--', ';'] 658 for word in bad_words: 659 if word in params_json.lower(): 660 warn(f"Aborting build_where() due to possible SQL injection.") 661 return '' 662 663 if connector is None: 664 from meerschaum import get_connector 665 connector = get_connector('sql') 666 where = "" 667 leading_and = "\n AND " 668 for key, value in params.items(): 669 _key = sql_item_name(key, connector.flavor, None) 670 ### search across a list (i.e. IN syntax) 671 if isinstance(value, Iterable) and not isinstance(value, (dict, str)): 672 includes = [ 673 none_if_null(item) 674 for item in value 675 if not str(item).startswith(negation_prefix) 676 ] 677 null_includes = [item for item in includes if item is None] 678 not_null_includes = [item for item in includes if item is not None] 679 excludes = [ 680 none_if_null(str(item)[len(negation_prefix):]) 681 for item in value 682 if str(item).startswith(negation_prefix) 683 ] 684 null_excludes = [item for item in excludes if item is None] 685 not_null_excludes = [item for item in excludes if item is not None] 686 687 if includes: 688 where += f"{leading_and}(" 689 if not_null_includes: 690 where += f"{_key} IN (" 691 for item in not_null_includes: 692 quoted_item = str(item).replace("'", "''") 693 where += f"'{quoted_item}', " 694 where = where[:-2] + ")" 695 if null_includes: 696 where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL" 697 if includes: 698 where += ")" 699 700 if excludes: 701 
where += f"{leading_and}(" 702 if not_null_excludes: 703 where += f"{_key} NOT IN (" 704 for item in not_null_excludes: 705 quoted_item = str(item).replace("'", "''") 706 where += f"'{quoted_item}', " 707 where = where[:-2] + ")" 708 if null_excludes: 709 where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL" 710 if excludes: 711 where += ")" 712 713 continue 714 715 ### search a dictionary 716 elif isinstance(value, dict): 717 import json 718 where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'") 719 continue 720 721 eq_sign = '=' 722 is_null = 'IS NULL' 723 if value_is_null(str(value).lstrip(negation_prefix)): 724 value = ( 725 (negation_prefix + 'None') 726 if str(value).startswith(negation_prefix) 727 else None 728 ) 729 if str(value).startswith(negation_prefix): 730 value = str(value)[len(negation_prefix):] 731 eq_sign = '!=' 732 if value_is_null(value): 733 value = None 734 is_null = 'IS NOT NULL' 735 quoted_value = str(value).replace("'", "''") 736 where += ( 737 f"{leading_and}{_key} " 738 + (is_null if value is None else f"{eq_sign} '{quoted_value}'") 739 ) 740 741 if len(where) > 1: 742 where = ("\nWHERE\n " if with_where else '') + where[len(leading_and):] 743 return where 744 745 746def table_exists( 747 table: str, 748 connector: mrsm.connectors.sql.SQLConnector, 749 schema: Optional[str] = None, 750 debug: bool = False, 751 ) -> bool: 752 """Check if a table exists. 753 754 Parameters 755 ---------- 756 table: str: 757 The name of the table in question. 758 759 connector: mrsm.connectors.sql.SQLConnector 760 The connector to the database which holds the table. 761 762 schema: Optional[str], default None 763 Optionally specify the table schema. 764 Defaults to `connector.schema`. 765 766 debug: bool, default False : 767 Verbosity toggle. 768 769 Returns 770 ------- 771 A `bool` indicating whether or not the table exists on the database. 
772 773 """ 774 sqlalchemy = mrsm.attempt_import('sqlalchemy') 775 schema = schema or connector.schema 776 insp = sqlalchemy.inspect(connector.engine) 777 truncated_table_name = truncate_item_name(str(table), connector.flavor) 778 return insp.has_table(truncated_table_name, schema=schema) 779 780 781def get_sqlalchemy_table( 782 table: str, 783 connector: Optional[meerschaum.connectors.sql.SQLConnector] = None, 784 schema: Optional[str] = None, 785 refresh: bool = False, 786 debug: bool = False, 787 ) -> 'sqlalchemy.Table': 788 """ 789 Construct a SQLAlchemy table from its name. 790 791 Parameters 792 ---------- 793 table: str 794 The name of the table on the database. Does not need to be escaped. 795 796 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 797 The connector to the database which holds the table. 798 799 schema: Optional[str], default None 800 Specify on which schema the table resides. 801 Defaults to the schema set in `connector`. 802 803 refresh: bool, default False 804 If `True`, rebuild the cached table object. 805 806 debug: bool, default False: 807 Verbosity toggle. 808 809 Returns 810 ------- 811 A `sqlalchemy.Table` object for the table. 
812 813 """ 814 if connector is None: 815 from meerschaum import get_connector 816 connector = get_connector('sql') 817 818 from meerschaum.connectors.sql.tables import get_tables 819 from meerschaum.utils.packages import attempt_import 820 from meerschaum.utils.warnings import warn 821 if refresh: 822 connector.metadata.clear() 823 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 824 sqlalchemy = attempt_import('sqlalchemy') 825 truncated_table_name = truncate_item_name(str(table), connector.flavor) 826 table_kwargs = { 827 'autoload_with': connector.engine, 828 } 829 if schema: 830 table_kwargs['schema'] = schema 831 832 if refresh or truncated_table_name not in tables: 833 try: 834 tables[truncated_table_name] = sqlalchemy.Table( 835 truncated_table_name, 836 connector.metadata, 837 **table_kwargs 838 ) 839 except sqlalchemy.exc.NoSuchTableError as e: 840 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 841 return None 842 return tables[truncated_table_name] 843 844 845def get_table_cols_types( 846 table: str, 847 connectable: Union[ 848 'mrsm.connectors.sql.SQLConnector', 849 'sqlalchemy.orm.session.Session', 850 'sqlalchemy.engine.base.Engine' 851 ], 852 flavor: Optional[str] = None, 853 schema: Optional[str] = None, 854 database: Optional[str] = None, 855 debug: bool = False, 856 ) -> Dict[str, str]: 857 """ 858 Return a dictionary mapping a table's columns to data types. 859 This is useful for inspecting tables creating during a not-yet-committed session. 860 861 NOTE: This may return incorrect columns if the schema is not explicitly stated. 862 Use this function if you are confident the table name is unique or if you have 863 and explicit schema. 864 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 865 866 Parameters 867 ---------- 868 table: str 869 The name of the table (unquoted). 
870 871 connectable: Union[ 872 'mrsm.connectors.sql.SQLConnector', 873 'sqlalchemy.orm.session.Session', 874 ] 875 The connection object used to fetch the columns and types. 876 877 flavor: Optional[str], default None 878 The database dialect flavor to use for the query. 879 If omitted, default to `connectable.flavor`. 880 881 schema: Optional[str], default None 882 If provided, restrict the query to this schema. 883 884 database: Optional[str]. default None 885 If provided, restrict the query to this database. 886 887 Returns 888 ------- 889 A dictionary mapping column names to data types. 890 """ 891 from meerschaum.connectors import SQLConnector 892 from meerschaum.utils.misc import filter_keywords 893 sqlalchemy = mrsm.attempt_import('sqlalchemy') 894 flavor = flavor or getattr(connectable, 'flavor', None) 895 if not flavor: 896 raise ValueError(f"Please provide a database flavor.") 897 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 898 raise ValueError(f"You must provide a SQLConnector when using DuckDB.") 899 if flavor in NO_SCHEMA_FLAVORS: 900 schema = None 901 if schema is None: 902 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 903 if flavor in ('sqlite', 'duckdb', 'oracle'): 904 database = None 905 table_trunc = truncate_item_name(table, flavor=flavor) 906 table_lower = table.lower() 907 table_upper = table.upper() 908 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 909 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 910 911 cols_types_query = sqlalchemy.text( 912 columns_types_queries.get( 913 flavor, 914 columns_types_queries['default'] 915 ).format( 916 table = table, 917 table_trunc = table_trunc, 918 table_lower = table_lower, 919 table_lower_trunc = table_lower_trunc, 920 table_upper = table_upper, 921 table_upper_trunc = table_upper_trunc, 922 ) 923 ) 924 925 cols = ['database', 'schema', 'table', 'column', 'type'] 926 result_cols_ix = dict(enumerate(cols)) 927 928 debug_kwargs = 
{'debug': debug} if isinstance(connectable, SQLConnector) else {} 929 if not debug_kwargs and debug: 930 dprint(cols_types_query) 931 932 try: 933 result_rows = ( 934 [ 935 row 936 for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall() 937 ] 938 if flavor != 'duckdb' 939 else [ 940 tuple([doc[col] for col in cols]) 941 for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records') 942 ] 943 ) 944 cols_types_docs = [ 945 { 946 result_cols_ix[i]: val 947 for i, val in enumerate(row) 948 } 949 for row in result_rows 950 ] 951 cols_types_docs_filtered = [ 952 doc 953 for doc in cols_types_docs 954 if ( 955 ( 956 not schema 957 or doc['schema'] == schema 958 ) 959 and 960 ( 961 not database 962 or doc['database'] == database 963 ) 964 ) 965 ] 966 967 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 968 if cols_types_docs and not cols_types_docs_filtered: 969 cols_types_docs_filtered = cols_types_docs 970 971 return { 972 ( 973 doc['column'] 974 if flavor != 'oracle' else ( 975 ( 976 doc['column'].lower() 977 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 978 else doc['column'] 979 ) 980 ) 981 ): doc['type'].upper() 982 for doc in cols_types_docs_filtered 983 } 984 except Exception as e: 985 warn(f"Failed to fetch columns for table '{table}':\n{e}") 986 return {} 987 988 989def get_update_queries( 990 target: str, 991 patch: str, 992 connectable: Union[ 993 mrsm.connectors.sql.SQLConnector, 994 'sqlalchemy.orm.session.Session' 995 ], 996 join_cols: Iterable[str], 997 flavor: Optional[str] = None, 998 upsert: bool = False, 999 datetime_col: Optional[str] = None, 1000 schema: Optional[str] = None, 1001 patch_schema: Optional[str] = None, 1002 debug: bool = False, 1003 ) -> List[str]: 1004 """ 1005 Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table. 
1006 1007 Parameters 1008 ---------- 1009 target: str 1010 The name of the target table. 1011 1012 patch: str 1013 The name of the patch table. This should have the same shape as the target. 1014 1015 connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session] 1016 The `SQLConnector` or SQLAlchemy session which will later execute the queries. 1017 1018 join_cols: List[str] 1019 The columns to use to join the patch to the target. 1020 1021 flavor: Optional[str], default None 1022 If using a SQLAlchemy session, provide the expected database flavor. 1023 1024 upsert: bool, default False 1025 If `True`, return an upsert query rather than an update. 1026 1027 datetime_col: Optional[str], default None 1028 If provided, bound the join query using this column as the datetime index. 1029 This must be present on both tables. 1030 1031 schema: Optional[str], default None 1032 If provided, use this schema when quoting the target table. 1033 Defaults to `connector.schema`. 1034 1035 patch_schema: Optional[str], default None 1036 If provided, use this schema when quoting the patch table. 1037 Defaults to `schema`. 1038 1039 debug: bool, default False 1040 Verbosity toggle. 1041 1042 Returns 1043 ------- 1044 A list of query strings to perform the update operation. 
1045 """ 1046 from meerschaum.connectors import SQLConnector 1047 from meerschaum.utils.debug import dprint 1048 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 1049 flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None) 1050 if not flavor: 1051 raise ValueError("Provide a flavor if using a SQLAlchemy session.") 1052 if ( 1053 flavor == 'sqlite' 1054 and isinstance(connectable, SQLConnector) 1055 and connectable.db_version < '3.33.0' 1056 ): 1057 flavor = 'sqlite_delete_insert' 1058 flavor_key = (f'{flavor}-upsert' if upsert else flavor) 1059 base_queries = update_queries.get( 1060 flavor_key, 1061 update_queries['default'] 1062 ) 1063 if not isinstance(base_queries, list): 1064 base_queries = [base_queries] 1065 schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None) 1066 patch_schema = patch_schema or schema 1067 target_table_columns = get_table_cols_types( 1068 target, 1069 connectable, 1070 flavor = flavor, 1071 schema = schema, 1072 debug = debug, 1073 ) 1074 patch_table_columns = get_table_cols_types( 1075 patch, 1076 connectable, 1077 flavor = flavor, 1078 schema = patch_schema, 1079 debug = debug, 1080 ) 1081 1082 patch_cols_str = ', '.join( 1083 [ 1084 sql_item_name(col, flavor) 1085 for col in patch_table_columns 1086 ] 1087 ) 1088 join_cols_str = ', '.join( 1089 [ 1090 sql_item_name(col, flavor) 1091 for col in join_cols 1092 ] 1093 ) 1094 1095 value_cols = [] 1096 join_cols_types = [] 1097 if debug: 1098 dprint(f"target_table_columns:") 1099 mrsm.pprint(target_table_columns) 1100 for c_name, c_type in target_table_columns.items(): 1101 if c_name not in patch_table_columns: 1102 continue 1103 if flavor in DB_FLAVORS_CAST_DTYPES: 1104 c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type) 1105 ( 1106 join_cols_types 1107 if c_name in join_cols 1108 else value_cols 1109 ).append((c_name, c_type)) 1110 if debug: 1111 dprint(f"value_cols: {value_cols}") 1112 
1113 if not join_cols_types: 1114 return [] 1115 if not value_cols and not upsert: 1116 return [] 1117 1118 coalesce_join_cols_str = ', '.join( 1119 [ 1120 'COALESCE(' 1121 + sql_item_name(c_name, flavor) 1122 + ', ' 1123 + get_null_replacement(c_type, flavor) 1124 + ')' 1125 for c_name, c_type in join_cols_types 1126 ] 1127 ) 1128 1129 update_or_nothing = ('UPDATE' if value_cols else 'NOTHING') 1130 1131 def sets_subquery(l_prefix: str, r_prefix: str): 1132 if not value_cols: 1133 return '' 1134 return 'SET ' + ',\n'.join([ 1135 ( 1136 l_prefix + sql_item_name(c_name, flavor, None) 1137 + ' = ' 1138 + ('CAST(' if flavor != 'sqlite' else '') 1139 + r_prefix 1140 + sql_item_name(c_name, flavor, None) 1141 + (' AS ' if flavor != 'sqlite' else '') 1142 + (c_type.replace('_', ' ') if flavor != 'sqlite' else '') 1143 + (')' if flavor != 'sqlite' else '') 1144 ) for c_name, c_type in value_cols 1145 ]) 1146 1147 def and_subquery(l_prefix: str, r_prefix: str): 1148 return '\nAND\n'.join([ 1149 ( 1150 "COALESCE(" 1151 + l_prefix 1152 + sql_item_name(c_name, flavor, None) 1153 + ", " 1154 + get_null_replacement(c_type, flavor) 1155 + ")" 1156 + ' = ' 1157 + "COALESCE(" 1158 + r_prefix 1159 + sql_item_name(c_name, flavor, None) 1160 + ", " 1161 + get_null_replacement(c_type, flavor) 1162 + ")" 1163 ) for c_name, c_type in join_cols_types 1164 ]) 1165 1166 target_table_name = sql_item_name(target, flavor, schema) 1167 patch_table_name = sql_item_name(patch, flavor, patch_schema) 1168 dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None 1169 date_bounds_subquery = ( 1170 f""" 1171 f.{dt_col_name} >= (SELECT MIN({dt_col_name}) FROM {patch_table_name}) 1172 AND f.{dt_col_name} <= (SELECT MAX({dt_col_name}) FROM {patch_table_name}) 1173 """ 1174 if datetime_col 1175 else "1 = 1" 1176 ) 1177 1178 return [ 1179 base_query.format( 1180 sets_subquery_none = sets_subquery('', 'p.'), 1181 sets_subquery_none_excluded = sets_subquery('', 'EXCLUDED.'), 1182 
sets_subquery_f = sets_subquery('f.', 'p.'), 1183 and_subquery_f = and_subquery('p.', 'f.'), 1184 and_subquery_t = and_subquery('p.', 't.'), 1185 target_table_name = target_table_name, 1186 patch_table_name = patch_table_name, 1187 patch_cols_str = patch_cols_str, 1188 date_bounds_subquery = date_bounds_subquery, 1189 join_cols_str = join_cols_str, 1190 coalesce_join_cols_str = coalesce_join_cols_str, 1191 update_or_nothing = update_or_nothing, 1192 ) 1193 for base_query in base_queries 1194 ] 1195 1196 1197def get_null_replacement(typ: str, flavor: str) -> str: 1198 """ 1199 Return a value that may temporarily be used in place of NULL for this type. 1200 1201 Parameters 1202 ---------- 1203 typ: str 1204 The typ to be converted to NULL. 1205 1206 flavor: str 1207 The database flavor for which this value will be used. 1208 1209 Returns 1210 ------- 1211 A value which may stand in place of NULL for this type. 1212 `'None'` is returned if a value cannot be determined. 1213 """ 1214 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 1215 if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'): 1216 return '-987654321' 1217 if 'bool' in typ.lower(): 1218 bool_typ = ( 1219 PD_TO_DB_DTYPES_FLAVORS 1220 .get('bool', {}) 1221 .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default']) 1222 ) 1223 if flavor in DB_FLAVORS_CAST_DTYPES: 1224 bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ) 1225 val_to_cast = ( 1226 -987654321 1227 if flavor in ('mysql', 'mariadb', 'sqlite', 'mssql') 1228 else 0 1229 ) 1230 return f'CAST({val_to_cast} AS {bool_typ})' 1231 if 'time' in typ.lower() or 'date' in typ.lower(): 1232 return dateadd_str(flavor=flavor, begin='1900-01-01') 1233 if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',): 1234 return '-987654321.0' 1235 return ('n' if flavor == 'oracle' else '') + "'-987654321'" 1236 1237 1238def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]: 1239 
""" 1240 Fetch the database version if possible. 1241 """ 1242 version_name = sql_item_name('version', conn.flavor, None) 1243 version_query = version_queries.get( 1244 conn.flavor, 1245 version_queries['default'] 1246 ).format(version_name=version_name) 1247 return conn.value(version_query, debug=debug) 1248 1249 1250def get_rename_table_queries( 1251 old_table: str, 1252 new_table: str, 1253 flavor: str, 1254 schema: Optional[str] = None, 1255 ) -> List[str]: 1256 """ 1257 Return queries to alter a table's name. 1258 1259 Parameters 1260 ---------- 1261 old_table: str 1262 The unquoted name of the old table. 1263 1264 new_table: str 1265 The unquoted name of the new table. 1266 1267 flavor: str 1268 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 1269 1270 schema: Optional[str], default None 1271 The schema on which the table resides. 1272 1273 Returns 1274 ------- 1275 A list of `ALTER TABLE` or equivalent queries for the database flavor. 1276 """ 1277 old_table_name = sql_item_name(old_table, flavor, schema) 1278 new_table_name = sql_item_name(new_table, flavor, None) 1279 tmp_table = '_tmp_rename_' + new_table 1280 tmp_table_name = sql_item_name(tmp_table, flavor, schema) 1281 if flavor == 'mssql': 1282 return [f"EXEC sp_rename '{old_table}', '{new_table}'"] 1283 1284 if flavor == 'duckdb': 1285 return [ 1286 get_create_table_query(f"SELECT * FROM {old_table_name}", tmp_table, 'duckdb', schema), 1287 get_create_table_query(f"SELECT * FROM {tmp_table_name}", new_table, 'duckdb', schema), 1288 f"DROP TABLE {tmp_table_name}", 1289 f"DROP TABLE {old_table_name}", 1290 ] 1291 1292 return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"] 1293 1294 1295def get_create_table_query( 1296 query: str, 1297 new_table: str, 1298 flavor: str, 1299 schema: Optional[str] = None, 1300 ) -> str: 1301 """ 1302 Return a query to create a new table from a `SELECT` query. 
1303 1304 Parameters 1305 ---------- 1306 query: str 1307 The select query to use for the creation of the table. 1308 1309 new_table: str 1310 The unquoted name of the new table. 1311 1312 flavor: str 1313 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 1314 1315 schema: Optional[str], default None 1316 The schema on which the table will reside. 1317 1318 Returns 1319 ------- 1320 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 1321 """ 1322 import textwrap 1323 create_cte = 'create_query' 1324 create_cte_name = sql_item_name(create_cte, flavor, None) 1325 new_table_name = sql_item_name(new_table, flavor, schema) 1326 if flavor in ('mssql',): 1327 query = query.lstrip() 1328 original_query = query 1329 if 'with ' in query.lower(): 1330 final_select_ix = query.lower().rfind('select') 1331 def_name = query[len('WITH '):].split(' ', maxsplit=1)[0] 1332 return ( 1333 query[:final_select_ix].rstrip() + ',\n' 1334 + f"{create_cte_name} AS (\n" 1335 + query[final_select_ix:] 1336 + "\n)\n" 1337 + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}" 1338 ) 1339 1340 create_table_query = f""" 1341 SELECT * 1342 INTO {new_table_name} 1343 FROM ({query}) AS {create_cte_name} 1344 """ 1345 elif flavor in (None,): 1346 create_table_query = f""" 1347 WITH {create_cte_name} AS ({query}) 1348 CREATE TABLE {new_table_name} AS 1349 SELECT * 1350 FROM {create_cte_name} 1351 """ 1352 elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'): 1353 create_table_query = f""" 1354 CREATE TABLE {new_table_name} AS 1355 SELECT * 1356 FROM ({query})""" + (f""" AS {create_cte_name}""" if flavor != 'oracle' else '') + """ 1357 """ 1358 else: 1359 create_table_query = f""" 1360 SELECT * 1361 INTO {new_table_name} 1362 FROM ({query}) AS {create_cte_name} 1363 """ 1364 1365 return textwrap.dedent(create_table_query) 1366 1367 1368def format_cte_subquery( 1369 sub_query: str, 1370 flavor: str, 1371 sub_name: str = 'src', 1372 
cols_to_select: Union[List[str], str] = '*', 1373 ) -> str: 1374 """ 1375 Given a subquery, build a wrapper query that selects from the CTE subquery. 1376 1377 Parameters 1378 ---------- 1379 sub_query: str 1380 The subquery to wrap. 1381 1382 flavor: str 1383 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 1384 1385 sub_name: str, default 'src' 1386 If possible, give this name to the CTE (must be unquoted). 1387 1388 cols_to_select: Union[List[str], str], default '' 1389 If specified, choose which columns to select from the CTE. 1390 If a list of strings is provided, each item will be quoted and joined with commas. 1391 If a string is given, assume it is quoted and insert it into the query. 1392 1393 Returns 1394 ------- 1395 A wrapper query that selects from the CTE. 1396 """ 1397 import textwrap 1398 quoted_sub_name = sql_item_name(sub_name, flavor, None) 1399 cols_str = ( 1400 cols_to_select 1401 if isinstance(cols_to_select, str) 1402 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 1403 ) 1404 return textwrap.dedent( 1405 f""" 1406 SELECT {cols_str} 1407 FROM ({sub_query})""" 1408 + (f' AS {quoted_sub_name}' if flavor != 'oracle' else '') + """ 1409 """ 1410 ) 1411 1412 1413def session_execute( 1414 session: 'sqlalchemy.orm.session.Session', 1415 queries: Union[List[str], str], 1416 with_results: bool = False, 1417 debug: bool = False, 1418 ) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 1419 """ 1420 Similar to `SQLConnector.exec_queries()`, execute a list of queries 1421 and roll back when one fails. 1422 1423 Parameters 1424 ---------- 1425 session: sqlalchemy.orm.session.Session 1426 A SQLAlchemy session representing a transaction. 1427 1428 queries: Union[List[str], str] 1429 A query or list of queries to be executed. 1430 If a query fails, roll back the session. 1431 1432 with_results: bool, default False 1433 If `True`, return a list of result objects. 
1434 1435 Returns 1436 ------- 1437 A `SuccessTuple` indicating the queries were successfully executed. 1438 If `with_results`, return the `SuccessTuple` and a list of results. 1439 """ 1440 sqlalchemy = mrsm.attempt_import('sqlalchemy') 1441 if not isinstance(queries, list): 1442 queries = [queries] 1443 successes, msgs, results = [], [], [] 1444 for query in queries: 1445 query_text = sqlalchemy.text(query) 1446 fail_msg = f"Failed to execute queries." 1447 try: 1448 result = session.execute(query_text) 1449 query_success = result is not None 1450 query_msg = "Success" if query_success else fail_msg 1451 except Exception as e: 1452 query_success = False 1453 query_msg = f"{fail_msg}\n{e}" 1454 result = None 1455 successes.append(query_success) 1456 msgs.append(query_msg) 1457 results.append(result) 1458 if not query_success: 1459 session.rollback() 1460 break 1461 success, msg = all(successes), '\n'.join(msgs) 1462 if with_results: 1463 return (success, msg), results 1464 return success, msg
def clean(substring: str) -> str:
    """
    Verify that a substring is safe to splice into a SQL query.

    The check is a case-insensitive search for a short list of banned
    fragments (`;`, `--`, and `drop `).

    Parameters
    ----------
    substring: str
        The text to be inspected. It is stringified before being checked.

    Raises
    ------
    Calls `meerschaum.utils.warnings.error()` when a banned fragment is found.
    Nothing is returned when the substring passes the check.
    """
    lowered = str(substring).lower()
    for banned in (';', '--', 'drop '):
        if banned in lowered:
            from meerschaum.utils.warnings import error
            error(f"Invalid string: '{substring}'")
Ensure a substring is clean enough to be inserted into a SQL query. Raises an exception when banned words are used.
def dateadd_str(
        flavor: str = 'postgresql',
        datepart: str = 'day',
        number: Union[int, float] = 0,
        begin: Union[str, datetime, int] = 'now'
    ) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime, int], default `'now'`
        Base datetime to which to add dateparts.
        Integers are returned as-is (stringified), and falsy values yield `''`.
        A string which cannot be parsed as a datetime is checked with `clean()`
        and then inserted into the clause verbatim (unquoted).

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.
    An empty string is returned for unsupported flavors.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor = 'mssql',
    ...     begin = datetime(2022, 1, 1, 0, 0),
    ...     number = 1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))"
    >>> dateadd_str(
    ...     flavor = 'postgresql',
    ...     begin = datetime(2022, 1, 1, 0, 0),
    ...     number = 1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    dateutil_parser = attempt_import('dateutil.parser')
    if 'int' in str(type(begin)).lower():
        return str(begin)
    if not begin:
        return ''

    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime)
            else f"'{begin}'"
        )

    da = ""
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS TIMESTAMP)" if begin != 'now'
            else "CAST(NOW() AT TIME ZONE 'utc' AS TIMESTAMP)"
        )
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS TIMESTAMP)" if begin != 'now' else 'NOW()'
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### NOTE: `DATETIME` only holds milliseconds, so trim the last three digits.
        if begin_time and begin_time.microsecond != 0:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS DATETIME)" if begin != 'now' else 'GETUTCDATE()'
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = f"CAST({begin} AS DATETIME(6))" if begin != 'now' else 'UTC_TIMESTAMP(6)'
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        ### NOTE: SQLite only recognizes the `now` time-value as a string literal;
        ###       a bare `now` would be parsed as a column name.
        sqlite_begin = "'now'" if begin == 'now' else begin
        da = f"datetime({sqlite_begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        ### NOTE: Resolve `now` client-side (UTC) so that it flows through the same
        ###       `TO_TIMESTAMP` path as every other parsed datetime.
        if begin == 'now':
            begin_time = datetime.now(timezone.utc).replace(tzinfo=None)
        if begin_time:
            begin = str(begin_time.strftime('%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
Generate a `DATEADD` clause depending on database flavor.
Parameters
flavor (str, default `'postgresql'`): SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. Currently supported flavors:
'postgresql'
'timescaledb'
'citus'
'cockroachdb'
'duckdb'
'mssql'
'mysql'
'mariadb'
'sqlite'
'oracle'
datepart (str, default `'day'`): Which part of the date to modify. Supported values:
'year'
'month'
'day'
'hour'
'minute'
'second'
- number (Union[int, float], default `0`): How many units to add to the date part.
- begin (Union[str, datetime], default `'now'`): Base datetime to which to add dateparts.
Returns
- The appropriate `DATEADD` string for the corresponding database flavor.
Examples
>>> dateadd_str(
... flavor = 'mssql',
... begin = datetime(2022, 1, 1, 0, 0),
... number = 1,
... )
"DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME))"
>>> dateadd_str(
... flavor = 'postgresql',
... begin = datetime(2022, 1, 1, 0, 0),
... number = 1,
... )
"CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"
def test_connection(
        self,
        **kw: Any
    ) -> Union[bool, None]:
    """
    Check whether the database can be reached with this connector.

    Parameters
    ----------
    **kw:
        Keyword arguments forwarded to `meerschaum.connectors.poll.retry_connect`.
        They override the defaults used here (a single attempt, no wait,
        no warnings, and this connector as the target).

    Returns
    -------
    The result of `retry_connect()` (`True` if a connection is made,
    otherwise `False` or `None`), or `False` if an exception is raised.

    """
    import warnings
    from meerschaum.connectors.poll import retry_connect
    retry_kwargs = {
        'max_retries': 1,
        'retry_wait': 0,
        'warn': False,
        'connector': self,
        **kw,
    }
    with warnings.catch_warnings():
        ### Silence the "Could not ..." warnings while probing.
        warnings.filterwarnings('ignore', 'Could not')
        try:
            outcome = retry_connect(**retry_kwargs)
        except Exception:
            outcome = False
    return outcome
Test if a successful connection to the database may be made.
Parameters
- **kw: The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.
Returns
`True` if a connection is made, otherwise `False` or `None` in case of failure.
def get_distinct_col_count(
        col: str,
        query: str,
        connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
        debug: bool = False
    ) -> Optional[int]:
    """
    Count the distinct values of one column in the result of a SQL query.

    Parameters
    ----------
    col: str
        The (unquoted) column in the query whose distinct values are counted.

    query: str
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None
        The SQLConnector to execute the query.
        Falls back to `mrsm.get_connector('sql')` when omitted.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of distinct values in the column,
    or `None` if the result cannot be converted to an integer.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    quoted_col = sql_item_name(col, connector.flavor, None)

    ### MySQL and MariaDB get nested subqueries; every other flavor gets CTEs.
    if connector.flavor in ('mysql', 'mariadb'):
        count_query = (
            f"""
        SELECT COUNT(*)
        FROM (
            SELECT DISTINCT {quoted_col}
            FROM ({query}) AS src
        ) AS dist"""
        )
    else:
        count_query = (
            f"""
        WITH src AS ( {query} ),
        dist AS ( SELECT DISTINCT {quoted_col} FROM src )
        SELECT COUNT(*) FROM dist"""
        )

    raw_count = connector.value(count_query, debug=debug)
    try:
        return int(raw_count)
    except Exception:
        return None
Returns the number of distinct items in a column of a SQL query.
Parameters
- col (str): The column in the query to count.
- query (str): The SQL query to count from.
- connector (Optional[mrsm.connectors.sql.SQLConnector], default None): The SQLConnector to execute the query.
- debug (bool, default False): Verbosity toggle.
Returns
- An `int` of the number of distinct values in the column, or `None` if the query fails.
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Quote a database item (and optionally its schema) for the given flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).

    schema: Optional[str], default None
        If provided, prefix the item with this schema (wrapped in the same quotes).
        Ignored for SQLite.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    name = truncate_item_name(str(item), flavor)
    if flavor != 'oracle':
        wrappers = table_wrappers.get(flavor, table_wrappers['default'])
    else:
        name = pg_capital(name)
        ### NOTE: System-reserved words must be quoted.
        oracle_reserved_words = (
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table',
        )
        wrappers = ('"', '"') if name.lower() in oracle_reserved_words else ('', '')

    opening, closing = wrappers[0], wrappers[1]
    quoted_item = opening + name + closing

    ### NOTE: SQLite does not support schemas.
    if flavor == 'sqlite' or schema is None:
        return quoted_item

    return opening + schema + closing + '.' + quoted_item
Parse SQL items depending on the flavor.
Parameters
- item (str): The database item (table, view, etc.) in need of quotes.
- flavor (str): The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).
Returns
- A `str` which contains the input `item` wrapped in the corresponding escape characters.
Examples
>>> sql_item_name('table', 'sqlite')
'"table"'
>>> sql_item_name('table', 'mssql')
"[table]"
>>> sql_item_name('table', 'postgresql', schema='abc')
'"abc"."table"'
def pg_capital(s: str) -> str:
    """
    Wrap an identifier in double quotes if it cannot be used as a bare name.

    A name is left bare only when it consists solely of lowercase ASCII letters,
    digits, and underscores, and does not begin with an underscore or a digit.
    Strings which already contain a double quote are returned untouched.

    Parameters
    ----------
    s: str
        The string to be escaped.

    Returns
    -------
    The input string wrapped in quotes only if it needs them.

    Examples
    --------
    >>> pg_capital("My Table")
    '"My Table"'
    >>> pg_capital('my_table')
    'my_table'
    >>> pg_capital('1col')
    '"1col"'

    """
    if '"' in s:
        return s

    first_char = s[:1]
    ### NOTE: A bare name may not start with a digit
    ###       (e.g. an unquoted `0` would be read as a numeric literal).
    needs_quotes = (
        first_char == '_'
        or first_char.isdigit()
        or any(
            not ('a' <= char <= 'z' or char.isdigit() or char == '_')
            for char in str(s)
        )
    )
    return ('"' + s + '"') if needs_quotes else s
If string contains a capital letter, wrap it in double quotes.
Parameters
- s (str): The string to be escaped.
Returns
- The input string wrapped in quotes only if it needs them.
Examples
>>> pg_capital("My Table")
'"My Table"'
>>> pg_capital('my_table')
'my_table'
def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.

    NOTE: This is currently a pass-through: the input string is returned unchanged.

    Parameters
    ----------
    s: str
        The name of the item.

    Returns
    -------
    The same string `s`.
    """
    return s
Capitalize the string of an item on an Oracle database.
def truncate_item_name(item: str, flavor: str) -> str:
    """
    Shorten an item name so that it fits the flavor's maximum name length.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.
        Flavors without an entry in `max_name_lens` use the `'default'` limit.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections
    flavor_max_len = max_name_lens.get(flavor, max_name_lens['default'])
    return truncate_string_sections(item, max_len=flavor_max_len)
Truncate item names to stay within the database flavor's character limit.
Parameters
- item (str): The database item being referenced. This string is the "canonical" name internally.
- flavor (str): The flavor of the database on which `item` resides.
Returns
- The truncated string.
def build_where(
        params: Dict[str, Any],
        connector: Optional[meerschaum.connectors.sql.SQLConnector] = None,
        with_where: bool = True,
    ) -> str:
    """
    Build the `WHERE` clause based on the input criteria.

    Parameters
    ----------
    params: Dict[str, Any]
        The keywords dictionary to convert into a WHERE clause.
        If a value is a string which begins with the negation prefix, negate that value
        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
        A negated null value (e.g. `_None`) will be interpreted as `IS NOT NULL`.
        Non-string iterables become `IN` / `NOT IN` lists,
        and dictionaries are compared against the column cast as text.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The Meerschaum SQLConnector that will be executing the query.
        The connector is used to extract the SQL dialect.

    with_where: bool, default True
        If `True`, include the leading `'WHERE'` string.

    Returns
    -------
    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.
    An empty string is returned when `params` yields no conditions
    or when a banned fragment (`drop `, `--`, `;`) is detected.

    Examples
    --------
    ```
    >>> print(build_where({'foo': [1, 2, 3]}))

    WHERE
     "foo" IN ('1', '2', '3')
    ```
    """
    import json
    from meerschaum.config.static import STATIC_CONFIG
    from meerschaum.utils.warnings import warn
    from meerschaum.utils.dtypes import value_is_null, none_if_null
    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
    try:
        params_json = json.dumps(params)
    except Exception:
        params_json = str(params)
    params_json_lower = params_json.lower()
    for word in ('drop ', '--', ';'):
        if word in params_json_lower:
            warn("Aborting build_where() due to possible SQL injection.")
            return ''

    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    def _quote(val: Any) -> str:
        """Return `val` as a SQL string literal with single quotes doubled."""
        return "'" + str(val).replace("'", "''") + "'"

    leading_and = "\n AND "
    clauses = []
    for key, value in params.items():
        _key = sql_item_name(key, connector.flavor, None)
        ### search across a list (i.e. IN syntax)
        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
            ### Materialize once so that one-shot iterables are not exhausted
            ### before the excludes are collected.
            items = list(value)
            includes = [
                none_if_null(item)
                for item in items
                if not str(item).startswith(negation_prefix)
            ]
            excludes = [
                none_if_null(str(item)[len(negation_prefix):])
                for item in items
                if str(item).startswith(negation_prefix)
            ]
            not_null_includes = [item for item in includes if item is not None]
            not_null_excludes = [item for item in excludes if item is not None]

            if includes:
                include_parts = []
                if not_null_includes:
                    include_parts.append(
                        f"{_key} IN ("
                        + ', '.join(_quote(item) for item in not_null_includes)
                        + ")"
                    )
                if len(not_null_includes) != len(includes):
                    include_parts.append(f"{_key} IS NULL")
                clauses.append("(" + "\n OR ".join(include_parts) + ")")

            if excludes:
                exclude_parts = []
                if not_null_excludes:
                    exclude_parts.append(
                        f"{_key} NOT IN ("
                        + ', '.join(_quote(item) for item in not_null_excludes)
                        + ")"
                    )
                if len(not_null_excludes) != len(excludes):
                    exclude_parts.append(f"{_key} IS NOT NULL")
                clauses.append("(" + "\n AND ".join(exclude_parts) + ")")

            continue

        ### search a dictionary
        if isinstance(value, dict):
            ### NOTE: The serialized JSON is a string literal like any other value,
            ###       so embedded single quotes must be doubled.
            clauses.append(f"CAST({_key} AS TEXT) = " + _quote(json.dumps(value)))
            continue

        eq_sign = '='
        is_null = 'IS NULL'
        if value_is_null(str(value).lstrip(negation_prefix)):
            value = (
                (negation_prefix + 'None')
                if str(value).startswith(negation_prefix)
                else None
            )
        if str(value).startswith(negation_prefix):
            value = str(value)[len(negation_prefix):]
            eq_sign = '!='
            if value_is_null(value):
                value = None
                is_null = 'IS NOT NULL'
        clauses.append(
            f"{_key} "
            + (is_null if value is None else f"{eq_sign} {_quote(value)}")
        )

    if not clauses:
        return ''
    return ("\nWHERE\n " if with_where else '') + leading_and.join(clauses)
Build the `WHERE` clause based on the input criteria.
Parameters
- params (Dict[str, Any]): The keywords dictionary to convert into a WHERE clause. If a value is a string which begins with an underscore, negate that value (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). A value of `_None` will be interpreted as `IS NOT NULL`.
- connector (Optional[meerschaum.connectors.sql.SQLConnector], default None): The Meerschaum SQLConnector that will be executing the query. The connector is used to extract the SQL dialect.
- with_where (bool, default True): If `True`, include the leading `'WHERE'` string.
Returns
- A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.
Examples
>>> print(build_where({'foo': [1, 2, 3]}))
WHERE
"foo" IN ('1', '2', '3')
def table_exists(
        table: str,
        connector: mrsm.connectors.sql.SQLConnector,
        schema: Optional[str] = None,
        debug: bool = False,
    ) -> bool:
    """
    Determine whether a table is present on the connector's database.

    Parameters
    ----------
    table: str
        The name of the table in question.
        It is truncated to the flavor's name limit before the lookup.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle (currently unused by this function).

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.

    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    effective_schema = schema or connector.schema
    inspector = sqlalchemy.inspect(connector.engine)
    lookup_name = truncate_item_name(str(table), connector.flavor)
    return inspector.has_table(lookup_name, schema=effective_schema)
Check if a table exists.
Parameters
- table (str): The name of the table in question.
- connector (mrsm.connectors.sql.SQLConnector): The connector to the database which holds the table.
- schema (Optional[str], default None): Optionally specify the table schema. Defaults to `connector.schema`.
- debug (bool, default False): Verbosity toggle.
Returns
- A `bool` indicating whether or not the table exists on the database.
def get_sqlalchemy_table(
        table: str,
        connector: Optional[meerschaum.connectors.sql.SQLConnector] = None,
        schema: Optional[str] = None,
        refresh: bool = False,
        debug: bool = False,
    ) -> 'sqlalchemy.Table':
    """
    Look up (or reflect and cache) the SQLAlchemy table object for a table name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The connector to the database which holds the table.
        Falls back to `get_connector('sql')` when omitted.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Only passed along to SQLAlchemy when provided.

    refresh: bool, default False
        If `True`, clear the connector's metadata and rebuild the cached table object.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table,
    or `None` (with a warning) if the table does not exist.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn
    if refresh:
        connector.metadata.clear()
    cached_tables = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy')
    table_key = truncate_item_name(str(table), connector.flavor)
    reflect_kwargs = {
        'autoload_with': connector.engine,
        **({'schema': schema} if schema else {}),
    }

    ### Serve the cached object unless a rebuild was requested.
    if not refresh and table_key in cached_tables:
        return cached_tables[table_key]

    try:
        cached_tables[table_key] = sqlalchemy.Table(
            table_key,
            connector.metadata,
            **reflect_kwargs
        )
    except sqlalchemy.exc.NoSuchTableError:
        warn(f"Table '{table_key}' does not exist in '{connector}'.")
        return None
    return cached_tables[table_key]
Construct a SQLAlchemy table from its name.
Parameters
- table (str): The name of the table on the database. Does not need to be escaped.
- connector (Optional[meerschaum.connectors.sql.SQLConnector], default None): The connector to the database which holds the table.
- schema (Optional[str], default None): Specify on which schema the table resides. Defaults to the schema set in `connector`.
- refresh (bool, default False): If `True`, rebuild the cached table object.
- debug (bool, default False): Verbosity toggle.
Returns
- A
sqlalchemy.Table
object for the table.
def get_table_cols_types(
        table: str,
        connectable: Union[
            'mrsm.connectors.sql.SQLConnector',
            'sqlalchemy.orm.session.Session',
            'sqlalchemy.engine.base.Engine'
        ],
        flavor: Optional[str] = None,
        schema: Optional[str] = None,
        database: Optional[str] = None,
        debug: bool = False,
    ) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
        Use this function if you are confident the table name is unique or if you have
        an explicit schema.
        To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine',
    ]
        The connection object used to fetch the columns and types.
        A `SQLConnector` is required when the flavor is `'duckdb'`.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the results to this schema.
        Ignored for flavors in `NO_SCHEMA_FLAVORS`; when unset,
        the flavor's entry in `DEFAULT_SCHEMA_FLAVORS` (if any) is used.

    database: Optional[str], default None
        If provided, restrict the results to this database.
        Ignored for `'sqlite'`, `'duckdb'`, and `'oracle'`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to (upper-cased) data types.
    An empty dictionary is returned (after a warning) if the lookup fails.

    Raises
    ------
    ValueError
        If no flavor can be determined, or if the flavor is `'duckdb'`
        and `connectable` is not a `SQLConnector`.
    """
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.misc import filter_keywords
    sqlalchemy = mrsm.attempt_import('sqlalchemy')

    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError(f"Please provide a database flavor.")

    is_connector = isinstance(connectable, SQLConnector)
    if flavor == 'duckdb' and not is_connector:
        raise ValueError(f"You must provide a SQLConnector when using DuckDB.")

    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None

    ### The query templates may reference any casing / truncation of the table name.
    lowered = table.lower()
    uppered = table.upper()
    name_variants = {
        'table': table,
        'table_trunc': truncate_item_name(table, flavor=flavor),
        'table_lower': lowered,
        'table_lower_trunc': truncate_item_name(lowered, flavor=flavor),
        'table_upper': uppered,
        'table_upper_trunc': truncate_item_name(uppered, flavor=flavor),
    }
    query_template = columns_types_queries.get(flavor, columns_types_queries['default'])
    cols_types_query = sqlalchemy.text(query_template.format(**name_variants))

    cols = ['database', 'schema', 'table', 'column', 'type']
    ix_to_col = dict(enumerate(cols))

    ### Only `SQLConnector.execute()` is passed `debug`; otherwise print the query here.
    debug_kwargs = {'debug': debug} if is_connector else {}
    if debug and not debug_kwargs:
        dprint(cols_types_query)

    def _in_scope(doc: Dict[str, Any]) -> bool:
        """Return whether a row matches the requested schema and database."""
        if schema and doc['schema'] != schema:
            return False
        if database and doc['database'] != database:
            return False
        return True

    try:
        if flavor != 'duckdb':
            result_rows = list(
                connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            )
        else:
            records = connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            result_rows = [
                tuple([record[col] for col in cols])
                for record in records
            ]

        all_docs = [
            {ix_to_col[i]: val for i, val in enumerate(row)}
            for row in result_rows
        ]
        scoped_docs = [doc for doc in all_docs if _in_scope(doc)]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if all_docs and not scoped_docs:
            scoped_docs = all_docs

        cols_types = {}
        for doc in scoped_docs:
            col_name = doc['column']
            ### Oracle: all-uppercase alphabetic names (underscores allowed) are lower-cased.
            if (
                flavor == 'oracle'
                and col_name.isupper()
                and col_name.replace('_', '').isalpha()
            ):
                col_name = col_name.lower()
            cols_types[col_name] = doc['type'].upper()
        return cols_types
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
Return a dictionary mapping a table's columns to data types. This is useful for inspecting tables created during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
an explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table()
instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union['mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session']): The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor
. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str], default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to data types.
def _parse_version_tuple(version: Any) -> Tuple[int, ...]:
    """
    Parse the leading dotted-integer portion of a version into a tuple of integers.

    The result is padded with zeros to at least three components
    (e.g. `'3.33'` becomes `(3, 33, 0)`).
    An empty tuple is returned when no leading digits are found (e.g. for `None`).
    """
    import re
    match = re.match(r'\s*(\d+(?:\.\d+)*)', str(version))
    if match is None:
        return ()
    parts = tuple(int(part) for part in match.group(1).split('.'))
    return parts + (0,) * max(0, 3 - len(parts))


def get_update_queries(
        target: str,
        patch: str,
        connectable: Union[
            mrsm.connectors.sql.SQLConnector,
            'sqlalchemy.orm.session.Session'
        ],
        join_cols: Iterable[str],
        flavor: Optional[str] = None,
        upsert: bool = False,
        datetime_col: Optional[str] = None,
        schema: Optional[str] = None,
        patch_schema: Optional[str] = None,
        debug: bool = False,
    ) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to a target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: Iterable[str]
        The columns to use to join the patch to the target.
        The iterable is materialized once, so one-shot iterators are accepted.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connectable.schema` when `connectable` is a `SQLConnector`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    An empty list is returned when none of `join_cols` exist on both tables,
    or when there are no value columns to set and `upsert` is `False`.

    Raises
    ------
    ValueError
        If no flavor was provided and `connectable` is not a `SQLConnector`.
    """
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES

    ### `join_cols` is iterated several times below, so materialize it up front.
    join_cols = list(join_cols)

    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")

    if flavor == 'sqlite' and isinstance(connectable, SQLConnector):
        ### Compare versions numerically: as strings, '3.9.0' sorts after '3.33.0'.
        ### An unparseable version yields an empty tuple and therefore also takes
        ### the delete/insert path instead of raising on the comparison.
        ### NOTE(review): `flavor` itself is overwritten here, so the later
        ### `flavor != 'sqlite'` checks and quoting calls see 'sqlite_delete_insert'
        ### -- confirm this is intended.
        if _parse_version_tuple(connectable.db_version) < (3, 33, 0):
            flavor = 'sqlite_delete_insert'

    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = update_queries.get(
        flavor_key,
        update_queries['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]

    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor = flavor,
        schema = schema,
        debug = debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor = flavor,
        schema = patch_schema,
        debug = debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Split the columns shared by both tables into join columns and value columns.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    coalesce_join_cols_str = ', '.join(
        [
            'COALESCE('
            + sql_item_name(c_name, flavor)
            + ', '
            + get_null_replacement(c_type, flavor)
            + ')'
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        """Build the `SET` clause, casting each patch value to the column type."""
        if not value_cols:
            return ''
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + ('CAST(' if flavor != 'sqlite' else '')
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + (' AS ' if flavor != 'sqlite' else '')
                + (c_type.replace('_', ' ') if flavor != 'sqlite' else '')
                + (')' if flavor != 'sqlite' else '')
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        """Build the join condition; both sides are coalesced so NULL keys compare equal."""
        return '\nAND\n'.join([
            (
                "COALESCE("
                + l_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
                + ' = '
                + "COALESCE("
                + r_prefix
                + sql_item_name(c_name, flavor, None)
                + ", "
                + get_null_replacement(c_type, flavor)
                + ")"
            ) for c_name, c_type in join_cols_types
        ])

    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    date_bounds_subquery = (
        f"""
        f.{dt_col_name} >= (SELECT MIN({dt_col_name}) FROM {patch_table_name})
        AND f.{dt_col_name} <= (SELECT MAX({dt_col_name}) FROM {patch_table_name})
        """
        if datetime_col
        else "1 = 1"
    )

    ### Every template receives the full set of keys and uses only those it references.
    return [
        base_query.format(
            sets_subquery_none = sets_subquery('', 'p.'),
            sets_subquery_none_excluded = sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f = sets_subquery('f.', 'p.'),
            and_subquery_f = and_subquery('p.', 'f.'),
            and_subquery_t = and_subquery('p.', 't.'),
            target_table_name = target_table_name,
            patch_table_name = patch_table_name,
            patch_cols_str = patch_cols_str,
            date_bounds_subquery = date_bounds_subquery,
            join_cols_str = join_cols_str,
            coalesce_join_cols_str = coalesce_join_cols_str,
            update_or_nothing = update_or_nothing,
        )
        for base_query in base_queries
    ]
Build a list of MERGE
, UPDATE
, DELETE
/INSERT
queries to apply a patch to a target table.
Parameters
- target (str): The name of the target table.
- patch (str): The name of the patch table. This should have the same shape as the target.
- connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]):
The
SQLConnector
or SQLAlchemy session which will later execute the queries. - join_cols (List[str]): The columns to use to join the patch to the target.
- flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
- upsert (bool, default False):
If
True
, return an upsert query rather than an update. - datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
- schema (Optional[str], default None):
If provided, use this schema when quoting the target table.
Defaults to
connector.schema
. - patch_schema (Optional[str], default None):
If provided, use this schema when quoting the patch table.
Defaults to
schema
. - debug (bool, default False): Verbosity toggle.
Returns
- A list of query strings to perform the update operation.
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a SQL expression that may temporarily stand in for NULL for this type.

    Parameters
    ----------
    typ: str
        The database type for which a stand-in value is needed.
        Matching is case-insensitive.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A SQL fragment which may stand in place of NULL for this type.
    Types matching no known category fall back to the quoted string
    `'-987654321'` (prefixed with `n` for Oracle).
    """
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    lowered_typ = typ.lower()

    ### Integers and generic numerics.
    if 'int' in lowered_typ or lowered_typ in ('numeric', 'number'):
        return '-987654321'

    ### Booleans are cast to the flavor's boolean type.
    if 'bool' in lowered_typ:
        default_bool_typ = PD_TO_DB_DTYPES_FLAVORS['bool']['default']
        bool_typ = PD_TO_DB_DTYPES_FLAVORS.get('bool', {}).get(flavor, default_bool_typ)
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        if flavor in ('mysql', 'mariadb', 'sqlite', 'mssql'):
            val_to_cast = -987654321
        else:
            val_to_cast = 0
        return f'CAST({val_to_cast} AS {bool_typ})'

    ### Dates and timestamps.
    if any(keyword in lowered_typ for keyword in ('time', 'date')):
        return dateadd_str(flavor=flavor, begin='1900-01-01')

    ### Floating-point and decimal types.
    if (
        any(keyword in lowered_typ for keyword in ('float', 'double'))
        or lowered_typ == 'decimal'
    ):
        return '-987654321.0'

    ### Everything else is treated as text.
    literal_prefix = 'n' if flavor == 'oracle' else ''
    return literal_prefix + "'-987654321'"
Return a value that may temporarily be used in place of NULL for this type.
Parameters
- typ (str): The database type for which a NULL stand-in value is needed.
- flavor (str): The database flavor for which this value will be used.
Returns
- A value which may stand in place of NULL for this type.
The quoted string '-987654321'
(prefixed with n for Oracle) is returned if the type matches no known category.
def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.

    Parameters
    ----------
    conn: SQLConnector
        The connector whose flavor selects the version query
        and whose `value()` method executes it.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    Whatever `conn.value()` returns for the flavor's version query.
    """
    quoted_alias = sql_item_name('version', conn.flavor, None)
    query_template = version_queries.get(conn.flavor, version_queries['default'])
    version_query = query_template.format(version_name=quoted_alias)
    return conn.value(version_query, debug=debug)
Fetch the database version if possible.
def get_rename_table_queries(
        old_table: str,
        new_table: str,
        flavor: str,
        schema: Optional[str] = None,
    ) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    For `'mssql'` this is a single `sp_rename` call; for `'duckdb'` the table
    is copied through a temporary table and the originals are dropped.
    """
    quoted_old = sql_item_name(old_table, flavor, schema)
    quoted_new = sql_item_name(new_table, flavor, None)
    staging_table = '_tmp_rename_' + new_table
    quoted_staging = sql_item_name(staging_table, flavor, schema)

    if flavor == 'mssql':
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if flavor != 'duckdb':
        return [f"ALTER TABLE {quoted_old} RENAME TO {quoted_new}"]

    ### DuckDB: copy old -> staging -> new, then drop the staging and old tables.
    copy_to_staging = get_create_table_query(
        f"SELECT * FROM {quoted_old}", staging_table, 'duckdb', schema
    )
    copy_to_new = get_create_table_query(
        f"SELECT * FROM {quoted_staging}", new_table, 'duckdb', schema
    )
    return [
        copy_to_staging,
        copy_to_new,
        f"DROP TABLE {quoted_staging}",
        f"DROP TABLE {quoted_old}",
    ]
Return queries to alter a table's name.
Parameters
- old_table (str): The unquoted name of the old table.
- new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
, 'postgresql'
). - schema (Optional[str], default None): The schema on which the table resides.
Returns
- A list of
ALTER TABLE
or equivalent queries for the database flavor.
def get_create_table_query(
        query: str,
        new_table: str,
        flavor: str,
        schema: Optional[str] = None,
    ) -> str:
    """
    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query: str
        The select query to use for the creation of the table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    import textwrap
    create_cte = 'create_query'
    create_cte_name = sql_item_name(create_cte, flavor, None)
    new_table_name = sql_item_name(new_table, flavor, schema)
    if flavor in ('mssql',):
        query = query.lstrip()
        ### Only treat the query as a CTE when it *begins* with the `WITH` keyword.
        ### A substring test would also match e.g. `WITH (NOLOCK)` table hints
        ### or string literals and then emit a malformed CTE list.
        begins_with_cte = query[:4].lower() == 'with' and query[4:5].isspace()
        if begins_with_cte:
            ### Append the final SELECT as one more CTE and select into the new table.
            ### NOTE(review): `rfind('select')` assumes the last `select` in the text
            ### is the outermost one -- a subquery inside the final SELECT would
            ### break this; confirm against callers.
            final_select_ix = query.lower().rfind('select')
            return (
                query[:final_select_ix].rstrip() + ',\n'
                + f"{create_cte_name} AS (\n"
                + query[final_select_ix:]
                + "\n)\n"
                + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}"
            )

        create_table_query = f"""
            SELECT *
            INTO {new_table_name}
            FROM ({query}) AS {create_cte_name}
        """
    elif flavor in (None,):
        create_table_query = f"""
            WITH {create_cte_name} AS ({query})
            CREATE TABLE {new_table_name} AS
            SELECT *
            FROM {create_cte_name}
        """
    elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'):
        ### The subquery alias is omitted for Oracle.
        create_table_query = f"""
            CREATE TABLE {new_table_name} AS
            SELECT *
            FROM ({query})""" + (f""" AS {create_cte_name}""" if flavor != 'oracle' else '') + """
        """
    else:
        create_table_query = f"""
            SELECT *
            INTO {new_table_name}
            FROM ({query}) AS {create_cte_name}
        """

    return textwrap.dedent(create_table_query)
Return a query to create a new table from a SELECT
query.
Parameters
- query (str): The select query to use for the creation of the table.
- new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
, 'postgresql'
). - schema (Optional[str], default None): The schema on which the table will reside.
Returns
- A
CREATE TABLE
(or SELECT INTO
) query for the database flavor.
def format_cte_subquery(
        sub_query: str,
        flavor: str,
        sub_name: str = 'src',
        cols_to_select: Union[List[str], str] = '*',
    ) -> str:
    """
    Given a subquery, build a wrapper query that selects from it.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        The unquoted alias to give the subquery.
        The alias is omitted for Oracle.

    cols_to_select: Union[List[str], str], default '*'
        Which columns to select from the subquery.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the subquery.
    """
    import textwrap
    quoted_alias = sql_item_name(sub_name, flavor, None)

    if isinstance(cols_to_select, str):
        select_list = cols_to_select
    else:
        select_list = ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select])

    alias_clause = '' if flavor == 'oracle' else f' AS {quoted_alias}'
    wrapper_query = (
        f"""
        SELECT {select_list}
        FROM ({sub_query})"""
        + alias_clause
        + """
        """
    )
    return textwrap.dedent(wrapper_query)
Given a subquery, build a wrapper query that selects from the CTE subquery.
Parameters
- sub_query (str): The subquery to wrap.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
, 'postgresql'
). - sub_name (str, default 'src'): If possible, give this name to the CTE (must be unquoted).
- cols_to_select (Union[List[str], str], default '*'): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.
Returns
- A wrapper query that selects from the CTE.
def session_execute(
        session: 'sqlalchemy.orm.session.Session',
        queries: Union[List[str], str],
        with_results: bool = False,
        debug: bool = False,
    ) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session and skip the remaining queries.

    with_results: bool, default False
        If `True`, return a list of result objects.

    debug: bool, default False
        Verbosity toggle (not referenced by this function's body).

    Returns
    -------
    A `SuccessTuple` indicating whether the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results
    (`None` in the position of a failed query).
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy')
    if not isinstance(queries, list):
        queries = [queries]

    failure_prefix = f"Failed to execute queries."
    all_succeeded = True
    messages, results = [], []

    for query in queries:
        statement = sqlalchemy.text(query)
        try:
            result = session.execute(statement)
            message = "Success" if result is not None else failure_prefix
        except Exception as e:
            result = None
            message = f"{failure_prefix}\n{e}"

        messages.append(message)
        results.append(result)

        ### A `None` result marks a failed query: undo the transaction and stop.
        if result is None:
            all_succeeded = False
            session.rollback()
            break

    success_tuple = (all_succeeded, '\n'.join(messages))
    if with_results:
        return success_tuple, results
    return success_tuple
Similar to SQLConnector.exec_queries()
, execute a list of queries
and roll back when one fails.
Parameters
- session (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
- queries (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
- with_results (bool, default False):
If
True
, return a list of result objects.
Returns
- A
SuccessTuple
indicating the queries were successfully executed. - If
with_results
, return the SuccessTuple
and a list of results.