meerschaum.utils.sql
Flavor-specific SQL tools.
1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# vim:fenc=utf-8 4 5""" 6Flavor-specific SQL tools. 7""" 8 9from __future__ import annotations 10 11from datetime import datetime, timezone, timedelta 12import meerschaum as mrsm 13from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple 14### Preserve legacy imports. 15from meerschaum.utils.dtypes.sql import ( 16 DB_TO_PD_DTYPES, 17 PD_TO_DB_DTYPES_FLAVORS, 18 get_pd_type_from_db_type as get_pd_type, 19 get_db_type_from_pd_type as get_db_type, 20 TIMEZONE_NAIVE_FLAVORS, 21) 22from meerschaum.utils.warnings import warn 23from meerschaum.utils.debug import dprint 24 25test_queries = { 26 'default' : 'SELECT 1', 27 'oracle' : 'SELECT 1 FROM DUAL', 28 'informix' : 'SELECT COUNT(*) FROM systables', 29 'hsqldb' : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS', 30} 31### `table_name` is the escaped name of the table. 32### `table` is the unescaped name of the table. 33exists_queries = { 34 'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0", 35} 36version_queries = { 37 'default': "SELECT VERSION() AS {version_name}", 38 'sqlite': "SELECT SQLITE_VERSION() AS {version_name}", 39 'mssql': "SELECT @@version", 40 'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1", 41} 42SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'} 43DROP_IF_EXISTS_FLAVORS = { 44 'timescaledb', 'postgresql', 'citus', 'mssql', 'mysql', 'mariadb', 'sqlite', 45} 46DROP_INDEX_IF_EXISTS_FLAVORS = { 47 'mssql', 'timescaledb', 'postgresql', 'sqlite', 'citus', 48} 49SKIP_AUTO_INCREMENT_FLAVORS = {'citus', 'duckdb'} 50COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'postgresql', 'citus'} 51UPDATE_QUERIES = { 52 'default': """ 53 UPDATE {target_table_name} AS f 54 {sets_subquery_none} 55 FROM {target_table_name} AS t 56 INNER JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 57 ON 58 {and_subquery_t} 59 WHERE 60 {and_subquery_f} 61 AND 62 {date_bounds_subquery} 63 """, 64 'timescaledb-upsert': """ 65 
INSERT INTO {target_table_name} ({patch_cols_str}) 66 SELECT {patch_cols_str} 67 FROM {patch_table_name} 68 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 69 """, 70 'postgresql-upsert': """ 71 INSERT INTO {target_table_name} ({patch_cols_str}) 72 SELECT {patch_cols_str} 73 FROM {patch_table_name} 74 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 75 """, 76 'citus-upsert': """ 77 INSERT INTO {target_table_name} ({patch_cols_str}) 78 SELECT {patch_cols_str} 79 FROM {patch_table_name} 80 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 81 """, 82 'cockroachdb-upsert': """ 83 INSERT INTO {target_table_name} ({patch_cols_str}) 84 SELECT {patch_cols_str} 85 FROM {patch_table_name} 86 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 87 """, 88 'mysql': """ 89 UPDATE {target_table_name} AS f 90 JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 91 ON 92 {and_subquery_f} 93 {sets_subquery_f} 94 WHERE 95 {date_bounds_subquery} 96 """, 97 'mysql-upsert': """ 98 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 99 SELECT {patch_cols_str} 100 FROM {patch_table_name} 101 {on_duplicate_key_update} 102 {cols_equal_values} 103 """, 104 'mariadb': """ 105 UPDATE {target_table_name} AS f 106 JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 107 ON 108 {and_subquery_f} 109 {sets_subquery_f} 110 WHERE 111 {date_bounds_subquery} 112 """, 113 'mariadb-upsert': """ 114 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 115 SELECT {patch_cols_str} 116 FROM {patch_table_name} 117 {on_duplicate_key_update} 118 {cols_equal_values} 119 """, 120 'mssql': """ 121 {with_temp_date_bounds} 122 MERGE {target_table_name} f 123 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 124 ON 125 {and_subquery_f} 126 AND 127 {date_bounds_subquery} 128 WHEN MATCHED THEN 129 UPDATE 130 {sets_subquery_none}; 131 """, 132 
'mssql-upsert': [ 133 "{identity_insert_on}", 134 """ 135 {with_temp_date_bounds} 136 MERGE {target_table_name} f 137 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 138 ON 139 {and_subquery_f} 140 AND 141 {date_bounds_subquery}{when_matched_update_sets_subquery_none} 142 WHEN NOT MATCHED THEN 143 INSERT ({patch_cols_str}) 144 VALUES ({patch_cols_prefixed_str}); 145 """, 146 "{identity_insert_off}", 147 ], 148 'oracle': """ 149 MERGE INTO {target_table_name} f 150 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 151 ON ( 152 {and_subquery_f} 153 AND 154 {date_bounds_subquery} 155 ) 156 WHEN MATCHED THEN 157 UPDATE 158 {sets_subquery_none} 159 """, 160 'oracle-upsert': """ 161 MERGE INTO {target_table_name} f 162 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 163 ON ( 164 {and_subquery_f} 165 AND 166 {date_bounds_subquery} 167 ){when_matched_update_sets_subquery_none} 168 WHEN NOT MATCHED THEN 169 INSERT ({patch_cols_str}) 170 VALUES ({patch_cols_prefixed_str}) 171 """, 172 'sqlite-upsert': """ 173 INSERT INTO {target_table_name} ({patch_cols_str}) 174 SELECT {patch_cols_str} 175 FROM {patch_table_name} 176 WHERE true 177 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 178 """, 179 'sqlite_delete_insert': [ 180 """ 181 DELETE FROM {target_table_name} AS f 182 WHERE ROWID IN ( 183 SELECT t.ROWID 184 FROM {target_table_name} AS t 185 INNER JOIN (SELECT * FROM {patch_table_name}) AS p 186 ON {and_subquery_t} 187 ); 188 """, 189 """ 190 INSERT INTO {target_table_name} AS f 191 SELECT {patch_cols_str} FROM {patch_table_name} AS p 192 """, 193 ], 194} 195columns_types_queries = { 196 'default': """ 197 SELECT 198 table_catalog AS database, 199 table_schema AS schema, 200 table_name AS table, 201 column_name AS column, 202 data_type AS type, 203 numeric_precision, 204 numeric_scale 205 FROM information_schema.columns 206 WHERE table_name IN ('{table}', '{table_trunc}') 207 """, 208 'sqlite': """ 209 SELECT 
210 '' "database", 211 '' "schema", 212 m.name "table", 213 p.name "column", 214 p.type "type" 215 FROM sqlite_master m 216 LEFT OUTER JOIN pragma_table_info(m.name) p 217 ON m.name <> p.name 218 WHERE m.type = 'table' 219 AND m.name IN ('{table}', '{table_trunc}') 220 """, 221 'mssql': """ 222 SELECT 223 TABLE_CATALOG AS [database], 224 TABLE_SCHEMA AS [schema], 225 TABLE_NAME AS [table], 226 COLUMN_NAME AS [column], 227 DATA_TYPE AS [type], 228 NUMERIC_PRECISION AS [numeric_precision], 229 NUMERIC_SCALE AS [numeric_scale] 230 FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS 231 WHERE TABLE_NAME IN ( 232 '{table}', 233 '{table_trunc}' 234 ) 235 236 """, 237 'mysql': """ 238 SELECT 239 TABLE_SCHEMA `database`, 240 TABLE_SCHEMA `schema`, 241 TABLE_NAME `table`, 242 COLUMN_NAME `column`, 243 DATA_TYPE `type`, 244 NUMERIC_PRECISION `numeric_precision`, 245 NUMERIC_SCALE `numeric_scale` 246 FROM INFORMATION_SCHEMA.COLUMNS 247 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 248 """, 249 'mariadb': """ 250 SELECT 251 TABLE_SCHEMA `database`, 252 TABLE_SCHEMA `schema`, 253 TABLE_NAME `table`, 254 COLUMN_NAME `column`, 255 DATA_TYPE `type`, 256 NUMERIC_PRECISION `numeric_precision`, 257 NUMERIC_SCALE `numeric_scale` 258 FROM INFORMATION_SCHEMA.COLUMNS 259 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 260 """, 261 'oracle': """ 262 SELECT 263 NULL AS "database", 264 NULL AS "schema", 265 TABLE_NAME AS "table", 266 COLUMN_NAME AS "column", 267 DATA_TYPE AS "type", 268 DATA_PRECISION AS "numeric_precision", 269 DATA_SCALE AS "numeric_scale" 270 FROM all_tab_columns 271 WHERE TABLE_NAME IN ( 272 '{table}', 273 '{table_trunc}', 274 '{table_lower}', 275 '{table_lower_trunc}', 276 '{table_upper}', 277 '{table_upper_trunc}' 278 ) 279 """, 280} 281hypertable_queries = { 282 'timescaledb': 'SELECT hypertable_size(\'{table_name}\')', 283 'citus': 'SELECT citus_table_size(\'{table_name}\')', 284} 285columns_indices_queries = { 286 'default': """ 287 SELECT 288 current_database() AS 
"database", 289 n.nspname AS "schema", 290 t.relname AS "table", 291 c.column_name AS "column", 292 i.relname AS "index", 293 CASE WHEN con.contype = 'p' THEN 'PRIMARY KEY' ELSE 'INDEX' END AS "index_type" 294 FROM pg_class t 295 INNER JOIN pg_index AS ix 296 ON t.oid = ix.indrelid 297 INNER JOIN pg_class AS i 298 ON i.oid = ix.indexrelid 299 INNER JOIN pg_namespace AS n 300 ON n.oid = t.relnamespace 301 INNER JOIN pg_attribute AS a 302 ON a.attnum = ANY(ix.indkey) 303 AND a.attrelid = t.oid 304 INNER JOIN information_schema.columns AS c 305 ON c.column_name = a.attname 306 AND c.table_name = t.relname 307 AND c.table_schema = n.nspname 308 LEFT JOIN pg_constraint AS con 309 ON con.conindid = i.oid 310 AND con.contype = 'p' 311 WHERE 312 t.relname IN ('{table}', '{table_trunc}') 313 AND n.nspname = '{schema}' 314 """, 315 'sqlite': """ 316 WITH indexed_columns AS ( 317 SELECT 318 '{table}' AS table_name, 319 pi.name AS column_name, 320 i.name AS index_name, 321 'INDEX' AS index_type 322 FROM 323 sqlite_master AS i, 324 pragma_index_info(i.name) AS pi 325 WHERE 326 i.type = 'index' 327 AND i.tbl_name = '{table}' 328 ), 329 primary_key_columns AS ( 330 SELECT 331 '{table}' AS table_name, 332 ti.name AS column_name, 333 'PRIMARY_KEY' AS index_name, 334 'PRIMARY KEY' AS index_type 335 FROM 336 pragma_table_info('{table}') AS ti 337 WHERE 338 ti.pk > 0 339 ) 340 SELECT 341 NULL AS "database", 342 NULL AS "schema", 343 "table_name" AS "table", 344 "column_name" AS "column", 345 "index_name" AS "index", 346 "index_type" 347 FROM indexed_columns 348 UNION ALL 349 SELECT 350 NULL AS "database", 351 NULL AS "schema", 352 table_name AS "table", 353 column_name AS "column", 354 index_name AS "index", 355 index_type 356 FROM primary_key_columns 357 """, 358 'mssql': """ 359 SELECT 360 NULL AS [database], 361 s.name AS [schema], 362 t.name AS [table], 363 c.name AS [column], 364 i.name AS [index], 365 CASE 366 WHEN kc.type = 'PK' THEN 'PRIMARY KEY' 367 ELSE 'INDEX' 368 END AS 
[index_type], 369 CASE 370 WHEN i.type = 1 THEN CAST(1 AS BIT) 371 ELSE CAST(0 AS BIT) 372 END AS [clustered] 373 FROM 374 sys.schemas s 375 INNER JOIN sys.tables t 376 ON s.schema_id = t.schema_id 377 INNER JOIN sys.indexes i 378 ON t.object_id = i.object_id 379 INNER JOIN sys.index_columns ic 380 ON i.object_id = ic.object_id 381 AND i.index_id = ic.index_id 382 INNER JOIN sys.columns c 383 ON ic.object_id = c.object_id 384 AND ic.column_id = c.column_id 385 LEFT JOIN sys.key_constraints kc 386 ON kc.parent_object_id = i.object_id 387 AND kc.type = 'PK' 388 AND kc.name = i.name 389 WHERE 390 t.name IN ('{table}', '{table_trunc}') 391 AND s.name = '{schema}' 392 AND i.type IN (1, 2) 393 """, 394 'oracle': """ 395 SELECT 396 NULL AS "database", 397 ic.table_owner AS "schema", 398 ic.table_name AS "table", 399 ic.column_name AS "column", 400 i.index_name AS "index", 401 CASE 402 WHEN c.constraint_type = 'P' THEN 'PRIMARY KEY' 403 WHEN i.uniqueness = 'UNIQUE' THEN 'UNIQUE INDEX' 404 ELSE 'INDEX' 405 END AS index_type 406 FROM 407 all_ind_columns ic 408 INNER JOIN all_indexes i 409 ON ic.index_name = i.index_name 410 AND ic.table_owner = i.owner 411 LEFT JOIN all_constraints c 412 ON i.index_name = c.constraint_name 413 AND i.table_owner = c.owner 414 AND c.constraint_type = 'P' 415 WHERE ic.table_name IN ( 416 '{table}', 417 '{table_trunc}', 418 '{table_upper}', 419 '{table_upper_trunc}' 420 ) 421 """, 422 'mysql': """ 423 SELECT 424 TABLE_SCHEMA AS `database`, 425 TABLE_SCHEMA AS `schema`, 426 TABLE_NAME AS `table`, 427 COLUMN_NAME AS `column`, 428 INDEX_NAME AS `index`, 429 CASE 430 WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY' 431 ELSE 'INDEX' 432 END AS `index_type` 433 FROM 434 information_schema.STATISTICS 435 WHERE 436 TABLE_NAME IN ('{table}', '{table_trunc}') 437 """, 438 'mariadb': """ 439 SELECT 440 TABLE_SCHEMA AS `database`, 441 TABLE_SCHEMA AS `schema`, 442 TABLE_NAME AS `table`, 443 COLUMN_NAME AS `column`, 444 INDEX_NAME AS `index`, 445 CASE 446 WHEN 
NON_UNIQUE = 0 THEN 'PRIMARY KEY' 447 ELSE 'INDEX' 448 END AS `index_type` 449 FROM 450 information_schema.STATISTICS 451 WHERE 452 TABLE_NAME IN ('{table}', '{table_trunc}') 453 """, 454} 455reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = { 456 'default': """ 457 SELECT SETVAL(pg_get_serial_sequence('{table}', '{column}'), {val}) 458 FROM {table_name} 459 """, 460 'mssql': """ 461 DBCC CHECKIDENT ('{table}', RESEED, {val}) 462 """, 463 'mysql': """ 464 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 465 """, 466 'mariadb': """ 467 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 468 """, 469 'sqlite': """ 470 UPDATE sqlite_sequence 471 SET seq = {val} 472 WHERE name = '{table}' 473 """, 474 'oracle': ( 475 "ALTER TABLE {table_name} MODIFY {column_name} " 476 "GENERATED BY DEFAULT ON NULL AS IDENTITY (START WITH {val_plus_1})" 477 ), 478} 479table_wrappers = { 480 'default' : ('"', '"'), 481 'timescaledb': ('"', '"'), 482 'citus' : ('"', '"'), 483 'duckdb' : ('"', '"'), 484 'postgresql' : ('"', '"'), 485 'sqlite' : ('"', '"'), 486 'mysql' : ('`', '`'), 487 'mariadb' : ('`', '`'), 488 'mssql' : ('[', ']'), 489 'cockroachdb': ('"', '"'), 490 'oracle' : ('"', '"'), 491} 492max_name_lens = { 493 'default' : 64, 494 'mssql' : 128, 495 'oracle' : 30, 496 'postgresql' : 64, 497 'timescaledb': 64, 498 'citus' : 64, 499 'cockroachdb': 64, 500 'sqlite' : 1024, ### Probably more, but 1024 seems more than reasonable. 
def clean(substring: str) -> str:
    """
    Ensure a substring is clean enough to be inserted into a SQL query.
    Raises an exception when banned words are used.
    """
    from meerschaum.utils.warnings import error
    ### NOTE: Best-effort guard against obvious injection attempts;
    ### not a substitute for parameterized queries.
    banned_symbols = [';', '--', 'drop ',]
    for symbol in banned_symbols:
        if symbol in str(substring).lower():
            error(f"Invalid string: '{substring}'")


def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now',
    db_type: Optional[str] = None,
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime], default `'now'`
        Base datetime to which to add dateparts.

    db_type: Optional[str], default None
        If provided, cast the datetime string as the type.
        Otherwise, infer this from the input datetime value.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor='mssql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
    >>> dateadd_str(
    ...     flavor='postgresql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type
    dateutil_parser = attempt_import('dateutil.parser')

    ### Integer `begin` (e.g. an epoch or serial value): emit plain arithmetic.
    if 'int' in str(type(begin)).lower():
        num_str = str(begin)
        if number is not None and number != 0:
            num_str += (
                f' + {number}'
                if number > 0
                else f" - {number * -1}"
            )
        return num_str
    if not begin:
        return ''

    _original_begin = begin
    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    ### Heuristic: an unparsed string is considered UTC-aware when it carries
    ### an explicit offset (e.g. '+00:00').
    dt_is_utc = (
        begin_time.tzinfo is not None
        if begin_time is not None
        else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1])
    )
    if db_type:
        db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower()
        dt_is_utc = dt_is_utc or db_type_is_utc
    db_type = db_type or get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### Drop the last 3 microsecond digits (keep milliseconds) so the quoted
        ### literal fits the non-UTC column type; `begin` is already quoted here.
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = (
            f"CAST({begin} AS DATETIME(6))"
            if begin != 'now'
            else 'UTC_TIMESTAMP(6)'
        )
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        if begin == 'now':
            ### BUGFIX: was r'%Y:%m:%d %M:%S.%f' (colon-separated date, missing
            ### the hour), which did not match `dt_format` below and broke
            ### TO_TIMESTAMP parsing. Now matches the `begin_time` branch.
            begin = str(
                datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y-%m-%d %H:%M:%S.%f')
            )
        elif begin_time:
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
def test_connection(
    self,
    **kw: Any
) -> Union[bool, None]:
    """
    Test if a successful connection to the database may be made.

    Parameters
    ----------
    **kw:
        The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.

    Returns
    -------
    `True` if a connection is made, otherwise `False` or `None` in case of failure.

    """
    import warnings
    from meerschaum.connectors.poll import retry_connect

    ### Single attempt, no waiting, no warning chatter; caller kwargs win.
    retry_kwargs = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self}
    retry_kwargs.update(kw)

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', 'Could not')
        try:
            return retry_connect(**retry_kwargs)
        except Exception:
            return False


def get_distinct_col_count(
    col: str,
    query: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    debug: bool = False
) -> Optional[int]:
    """
    Returns the number of distinct items in a column of a SQL query.

    Parameters
    ----------
    col: str:
        The column in the query to count.

    query: str:
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None:
        The SQLConnector to execute the query.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of columns in the query or `None` if the query fails.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    quoted_col = sql_item_name(col, connector.flavor, None)

    ### MySQL/MariaDB lack reliable CTE support here; wrap as a subquery instead.
    if connector.flavor in ('mysql', 'mariadb'):
        count_query = (
            f"""
        SELECT COUNT(*)
        FROM (
            SELECT DISTINCT {quoted_col}
            FROM ({query}) AS src
        ) AS dist"""
        )
    else:
        count_query = (
            f"""
        WITH src AS ( {query} ),
        dist AS ( SELECT DISTINCT {quoted_col} FROM src )
        SELECT COUNT(*) FROM dist"""
        )

    result = connector.value(count_query, debug=debug)
    try:
        return int(result)
    except Exception:
        return None
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Parse SQL items depending on the flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.).

    schema: Optional[str], default None
        If provided, prefix the table name with the schema.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    truncated_item = truncate_item_name(str(item), flavor)
    if flavor == 'oracle':
        truncated_item = pg_capital(truncated_item, quote_capitals=True)
        ### NOTE: System-reserved words must be quoted.
        reserved_words = {
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table', 'row',
        }
        wrappers = ('"', '"') if truncated_item.lower() in reserved_words else ('', '')
    else:
        wrappers = table_wrappers.get(flavor, table_wrappers['default'])

    ### NOTE: SQLite does not support schemas; neither do MSSQL temp tables ('#...').
    if flavor == 'sqlite' or (flavor == 'mssql' and str(item).startswith('#')):
        schema = None

    schema_prefix = (
        f"{wrappers[0]}{schema}{wrappers[1]}."
        if schema is not None
        else ''
    )
    return f"{schema_prefix}{wrappers[0]}{truncated_item}{wrappers[1]}"


def pg_capital(s: str, quote_capitals: bool = True) -> str:
    """
    If string contains a capital letter, wrap it in double quotes.

    Parameters
    ----------
    s: str
        The string to be escaped.

    quote_capitals: bool, default True
        If `False`, do not quote strings that contain only a mix of capital
        and lower-case letters.

    Returns
    -------
    The input string wrapped in quotes only if it needs them.

    Examples
    --------
    >>> pg_capital("My Table")
    '"My Table"'
    >>> pg_capital('my_table')
    'my_table'

    """
    ### Already quoted: leave untouched.
    if s.startswith('"') and s.endswith('"'):
        return s

    s = s.replace('"', '')

    def _forces_quoting(c: str) -> bool:
        """Return whether a single character requires the identifier be quoted."""
        if c == '_':
            return False
        return (not c.isalnum()) or (quote_capitals and c.isupper())

    needs_quotes = s.startswith('_') or any(_forces_quoting(c) for c in s)
    return f'"{s}"' if needs_quotes else s


def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.
    Currently a pass-through, retained for backwards compatibility.
    """
    return s
def truncate_item_name(item: str, flavor: str) -> str:
    """
    Truncate item names to stay within the database flavor's character limit.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections
    return truncate_string_sections(
        item, max_len=max_name_lens.get(flavor, max_name_lens['default'])
    )


def build_where(
    params: Dict[str, Any],
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    with_where: bool = True,
) -> str:
    """
    Build the `WHERE` clause based on the input criteria.

    Parameters
    ----------
    params: Dict[str, Any]:
        The keywords dictionary to convert into a WHERE clause.
        If a value is a string which begins with an underscore, negate that value
        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
        A value of `_None` will be interpreted as `IS NOT NULL`.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None:
        The Meerschaum SQLConnector that will be executing the query.
        The connector is used to extract the SQL dialect.

    with_where: bool, default True:
        If `True`, include the leading `'WHERE'` string.

    Returns
    -------
    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.

    Examples
    --------
    ```
    >>> print(build_where({'foo': [1, 2, 3]}))

    WHERE
        "foo" IN ('1', '2', '3')
    ```
    """
    import json
    from meerschaum.config.static import STATIC_CONFIG
    from meerschaum.utils.warnings import warn
    from meerschaum.utils.dtypes import value_is_null, none_if_null
    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']

    ### Serialize the params for a crude injection check on the whole payload.
    try:
        params_json = json.dumps(params)
    except Exception:
        params_json = str(params)
    bad_words = ['drop ', '--', ';']
    for word in bad_words:
        if word in params_json.lower():
            warn("Aborting build_where() due to possible SQL injection.")
            return ''

    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')
    where = ""
    leading_and = "\n    AND "
    for key, value in params.items():
        _key = sql_item_name(key, connector.flavor, None)
        ### search across a list (i.e. IN syntax)
        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
            ### Split the list into inclusion and (negation-prefixed) exclusion terms,
            ### and each of those into NULL and non-NULL members.
            includes = [
                none_if_null(item)
                for item in value
                if not str(item).startswith(negation_prefix)
            ]
            null_includes = [item for item in includes if item is None]
            not_null_includes = [item for item in includes if item is not None]
            excludes = [
                none_if_null(str(item)[len(negation_prefix):])
                for item in value
                if str(item).startswith(negation_prefix)
            ]
            null_excludes = [item for item in excludes if item is None]
            not_null_excludes = [item for item in excludes if item is not None]

            if includes:
                where += f"{leading_and}("
            if not_null_includes:
                where += f"{_key} IN ("
                for item in not_null_includes:
                    ### Escape single quotes by doubling them.
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_includes:
                where += ("\n    OR " if not_null_includes else "") + f"{_key} IS NULL"
            if includes:
                where += ")"

            if excludes:
                where += f"{leading_and}("
            if not_null_excludes:
                where += f"{_key} NOT IN ("
                for item in not_null_excludes:
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_excludes:
                where += ("\n    AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
            if excludes:
                where += ")"

            continue

        ### search a dictionary
        elif isinstance(value, dict):
            where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'")
            continue

        ### Scalar comparison, with optional negation prefix and NULL handling.
        eq_sign = '='
        is_null = 'IS NULL'
        if value_is_null(str(value).lstrip(negation_prefix)):
            value = (
                (negation_prefix + 'None')
                if str(value).startswith(negation_prefix)
                else None
            )
        if str(value).startswith(negation_prefix):
            value = str(value)[len(negation_prefix):]
            eq_sign = '!='
            if value_is_null(value):
                value = None
                is_null = 'IS NOT NULL'
        quoted_value = str(value).replace("'", "''")
        where += (
            f"{leading_and}{_key} "
            + (is_null if value is None else f"{eq_sign} '{quoted_value}'")
        )

    ### Strip the first dangling AND and optionally prepend the WHERE keyword.
    if len(where) > 1:
        where = ("\nWHERE\n    " if with_where else '') + where[len(leading_and):]
    return where
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """Check if a table exists.

    Parameters
    ----------
    table: str:
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False :
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    inspector = sqlalchemy.inspect(connector.engine)
    ### Inspect with the truncated (flavor-safe) name, falling back to the
    ### connector's configured schema.
    return inspector.has_table(
        truncate_item_name(str(table), connector.flavor),
        schema=(schema or connector.schema),
    )
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None:
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    ### DuckDB connectors do not return reflected table objects.
    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn

    if refresh:
        connector.metadata.clear()
    tables_cache = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy', lazy=False)
    truncated_name = truncate_item_name(str(table), connector.flavor)

    reflect_kwargs = {'autoload_with': connector.engine}
    if schema:
        reflect_kwargs['schema'] = schema

    ### Reflect the table and store it in the cache (or rebuild when refreshing).
    if refresh or truncated_name not in tables_cache:
        try:
            tables_cache[truncated_name] = sqlalchemy.Table(
                truncated_name,
                connector.metadata,
                **reflect_kwargs
            )
        except sqlalchemy.exc.NoSuchTableError:
            warn(f"Table '{truncated_name}' does not exist in '{connector}'.")
            return None
    return tables_cache[truncated_name]
1120 1121 """ 1122 if connector is None: 1123 from meerschaum import get_connector 1124 connector = get_connector('sql') 1125 1126 if connector.flavor == 'duckdb': 1127 return None 1128 1129 from meerschaum.connectors.sql.tables import get_tables 1130 from meerschaum.utils.packages import attempt_import 1131 from meerschaum.utils.warnings import warn 1132 if refresh: 1133 connector.metadata.clear() 1134 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 1135 sqlalchemy = attempt_import('sqlalchemy', lazy=False) 1136 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1137 table_kwargs = { 1138 'autoload_with': connector.engine, 1139 } 1140 if schema: 1141 table_kwargs['schema'] = schema 1142 1143 if refresh or truncated_table_name not in tables: 1144 try: 1145 tables[truncated_table_name] = sqlalchemy.Table( 1146 truncated_table_name, 1147 connector.metadata, 1148 **table_kwargs 1149 ) 1150 except sqlalchemy.exc.NoSuchTableError: 1151 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 1152 return None 1153 return tables[truncated_table_name] 1154 1155 1156def get_table_cols_types( 1157 table: str, 1158 connectable: Union[ 1159 'mrsm.connectors.sql.SQLConnector', 1160 'sqlalchemy.orm.session.Session', 1161 'sqlalchemy.engine.base.Engine' 1162 ], 1163 flavor: Optional[str] = None, 1164 schema: Optional[str] = None, 1165 database: Optional[str] = None, 1166 debug: bool = False, 1167) -> Dict[str, str]: 1168 """ 1169 Return a dictionary mapping a table's columns to data types. 1170 This is useful for inspecting tables creating during a not-yet-committed session. 1171 1172 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1173 Use this function if you are confident the table name is unique or if you have 1174 and explicit schema. 1175 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to data types.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    ### Some flavors have no schema concept; others get the flavor's default schema.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    ### These flavors do not support filtering on a separate `database`.
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Precompute every case/truncation variant the flavor-specific templates may reference.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        textwrap.dedent(columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )).lstrip().rstrip()
    )

    ### Positional result columns expected from every flavor's query.
    cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale']
    result_cols_ix = dict(enumerate(cols))

    ### Only a SQLConnector accepts a `debug` kwarg on `execute()`.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        result_rows = (
            [
                row
                for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        ### Keep only rows matching the requested schema / database (when given).
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        ### Fall back to the unfiltered rows rather than returning nothing.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle upper-cases unquoted identifiers; normalize all-caps alphabetic
        ### names back to lowercase. Append (precision,scale) when both are set.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper() + (
                f'({precision},{scale})'
                if (
                    (precision := doc.get('numeric_precision', None))
                    and
                    (scale := doc.get('numeric_scale', None))
                )
                else ''
            )
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        ### Best-effort: warn and return an empty mapping instead of raising.
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[str]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to a list of indices.
    """
    import textwrap
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    ### Some flavors have no schema concept; others get the flavor's default schema.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    ### These flavors do not support filtering on a separate `database`.
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Precompute every case/truncation variant the flavor-specific templates may reference.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        textwrap.dedent(columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )).lstrip().rstrip()
    )

    ### Positional result columns expected from every flavor's query.
    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    if flavor == 'mssql':
        cols.append('clustered')
    result_cols_ix = dict(enumerate(cols))

    ### Only a SQLConnector accepts a `debug` kwarg on `execute()`.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        ### Keep only rows matching the requested schema / database (when given).
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]
        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        ### Fall back to the unfiltered rows rather than returning nothing.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        cols_indices = defaultdict(lambda: [])
        for doc in cols_types_docs_filtered:
            ### Oracle upper-cases unquoted identifiers; normalize all-caps
            ### alphabetic names back to lowercase.
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            index_doc = {
                'name': doc.get('index', None),
                'type': doc.get('index_type', None)
            }
            if flavor == 'mssql':
                index_doc['clustered'] = doc.get('clustered', None)
            cols_indices[col].append(index_doc)

        return dict(cols_indices)
    except Exception as e:
        ### Best-effort: warn and return an empty mapping instead of raising.
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    identity_insert: bool = False,
    null_indices: bool = True,
    cast_columns: bool = True,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: List[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    identity_insert: bool, default False
        If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries.
        Only applies for MSSQL upserts.

    null_indices: bool, default True
        If `False`, do not coalesce index columns before joining.

    cast_columns: bool, default True
        If `False`, do not cast update columns to the target table types.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes import are_dtypes_equal
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")
    ### SQLite gained UPDATE ... FROM in 3.33.0; older versions need DELETE + INSERT.
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and connectable.db_version < '3.33.0'
    ):
        flavor = 'sqlite_delete_insert'
    ### Upserts use separate templates keyed as '<flavor>-upsert'.
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = UPDATE_QUERIES.get(
        flavor_key,
        UPDATE_QUERIES['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Split shared columns into join (index) columns and value (payload) columns.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    ### Nothing to join on, or nothing to update (and not upserting): no queries.
    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    coalesce_join_cols_str = ', '.join(
        [
            (
                (
                    'COALESCE('
                    + sql_item_name(c_name, flavor)
                    + ', '
                    + get_null_replacement(c_type, flavor)
                    + ')'
                )
                if null_indices
                else sql_item_name(c_name, flavor)
            )
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str) -> str:
        ### Build the `SET col = CAST(p.col AS type)` list for the update clause.
        if not value_cols:
            return ''

        ### Columns whose pandas dtype is timezone-aware (UTC) on tz-capable flavors.
        utc_value_cols = {
            c_name
            for c_name, c_type in value_cols
            if ('utc' in get_pd_type_from_db_type(c_type).lower())
        } if flavor not in TIMEZONE_NAIVE_FLAVORS else set()

        ### (prefix, infix, suffix) wrappers around each assignment's RHS.
        cast_func_cols = {
            c_name: (
                ('', '', '')
                if not cast_columns or (
                    flavor == 'oracle'
                    and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes')
                )
                else (
                    ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + (
                        " AT TIME ZONE 'UTC'"
                        if c_name in utc_value_cols
                        else ''
                    ))
                    if flavor != 'sqlite'
                    else ('', '', '')
                )
            )
            for c_name, c_type in value_cols
        }
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + cast_func_cols[c_name][0]
                + r_prefix + sql_item_name(c_name, flavor, None)
                + cast_func_cols[c_name][1]
                + cast_func_cols[c_name][2]
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str) -> str:
        ### Build the join predicate, coalescing NULLs so NULL == NULL matches
        ### (unless `null_indices` is False).
        return '\n AND\n '.join([
            (
                (
                    "COALESCE("
                    + l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                    + '\n =\n '
                    + "COALESCE("
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                )
                if null_indices
                else (
                    l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ' = '
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                )
            ) for c_name, c_type in join_cols_types
        ])

    skip_query_val = ""
    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    ### MSSQL reads the bounds from a temporary CTE instead of rescanning the patch.
    date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]'
    min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]'
    max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]'
    date_bounds_subquery = (
        f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table})
 AND
 f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})"""
        if datetime_col
        else "1 = 1"
    )
    with_temp_date_bounds = f"""WITH [date_bounds] AS (
 SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name}
 FROM {patch_table_name}
 )""" if datetime_col else ""
    identity_insert_on = (
        f"SET IDENTITY_INSERT {target_table_name} ON"
        if identity_insert
        else skip_query_val
    )
    identity_insert_off = (
        f"SET IDENTITY_INSERT {target_table_name} OFF"
        if identity_insert
        else skip_query_val
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "\n WHEN MATCHED THEN\n"
        f" UPDATE {sets_subquery('', 'p.')}"
    )

    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    ### Render every template with the full set of substitutions; templates only
    ### consume the fields they reference.
    formatted_queries = [
        textwrap.dedent(base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
            with_temp_date_bounds=with_temp_date_bounds,
            identity_insert_on=identity_insert_on,
            identity_insert_off=identity_insert_off,
        )).lstrip().rstrip()
        for base_query in base_queries
    ]

    ### NOTE: Allow for skipping some queries.
    return [query for query in formatted_queries if query]
1790 return [query for query in formatted_queries if query] 1791 1792 1793def get_null_replacement(typ: str, flavor: str) -> str: 1794 """ 1795 Return a value that may temporarily be used in place of NULL for this type. 1796 1797 Parameters 1798 ---------- 1799 typ: str 1800 The typ to be converted to NULL. 1801 1802 flavor: str 1803 The database flavor for which this value will be used. 1804 1805 Returns 1806 ------- 1807 A value which may stand in place of NULL for this type. 1808 `'None'` is returned if a value cannot be determined. 1809 """ 1810 from meerschaum.utils.dtypes import are_dtypes_equal 1811 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 1812 if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'): 1813 return '-987654321' 1814 if 'bool' in typ.lower() or typ.lower() == 'bit': 1815 bool_typ = ( 1816 PD_TO_DB_DTYPES_FLAVORS 1817 .get('bool', {}) 1818 .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default']) 1819 ) 1820 if flavor in DB_FLAVORS_CAST_DTYPES: 1821 bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ) 1822 val_to_cast = ( 1823 -987654321 1824 if flavor in ('mysql', 'mariadb') 1825 else 0 1826 ) 1827 return f'CAST({val_to_cast} AS {bool_typ})' 1828 if 'time' in typ.lower() or 'date' in typ.lower(): 1829 db_type = typ if typ.isupper() else None 1830 return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type) 1831 if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',): 1832 return '-987654321.0' 1833 if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char': 1834 return "'-987654321'" 1835 if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'): 1836 return '00' 1837 if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'): 1838 magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5' 1839 if flavor == 'mssql': 1840 return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)" 1841 return f"'{magic_val}'" 1842 return ('n' if flavor == 'oracle' else '') + "'-987654321'" 
1843 1844 1845def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]: 1846 """ 1847 Fetch the database version if possible. 1848 """ 1849 version_name = sql_item_name('version', conn.flavor, None) 1850 version_query = version_queries.get( 1851 conn.flavor, 1852 version_queries['default'] 1853 ).format(version_name=version_name) 1854 return conn.value(version_query, debug=debug) 1855 1856 1857def get_rename_table_queries( 1858 old_table: str, 1859 new_table: str, 1860 flavor: str, 1861 schema: Optional[str] = None, 1862) -> List[str]: 1863 """ 1864 Return queries to alter a table's name. 1865 1866 Parameters 1867 ---------- 1868 old_table: str 1869 The unquoted name of the old table. 1870 1871 new_table: str 1872 The unquoted name of the new table. 1873 1874 flavor: str 1875 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 1876 1877 schema: Optional[str], default None 1878 The schema on which the table resides. 1879 1880 Returns 1881 ------- 1882 A list of `ALTER TABLE` or equivalent queries for the database flavor. 
1883 """ 1884 old_table_name = sql_item_name(old_table, flavor, schema) 1885 new_table_name = sql_item_name(new_table, flavor, None) 1886 tmp_table = '_tmp_rename_' + new_table 1887 tmp_table_name = sql_item_name(tmp_table, flavor, schema) 1888 if flavor == 'mssql': 1889 return [f"EXEC sp_rename '{old_table}', '{new_table}'"] 1890 1891 if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else "" 1892 if flavor == 'duckdb': 1893 return ( 1894 get_create_table_queries( 1895 f"SELECT * FROM {old_table_name}", 1896 tmp_table, 1897 'duckdb', 1898 schema, 1899 ) + get_create_table_queries( 1900 f"SELECT * FROM {tmp_table_name}", 1901 new_table, 1902 'duckdb', 1903 schema, 1904 ) + [ 1905 f"DROP TABLE {if_exists_str} {tmp_table_name}", 1906 f"DROP TABLE {if_exists_str} {old_table_name}", 1907 ] 1908 ) 1909 1910 return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"] 1911 1912 1913def get_create_table_query( 1914 query_or_dtypes: Union[str, Dict[str, str]], 1915 new_table: str, 1916 flavor: str, 1917 schema: Optional[str] = None, 1918) -> str: 1919 """ 1920 NOTE: This function is deprecated. Use `get_create_table_queries()` instead. 1921 1922 Return a query to create a new table from a `SELECT` query. 1923 1924 Parameters 1925 ---------- 1926 query: Union[str, Dict[str, str]] 1927 The select query to use for the creation of the table. 1928 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 1929 1930 new_table: str 1931 The unquoted name of the new table. 1932 1933 flavor: str 1934 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 1935 1936 schema: Optional[str], default None 1937 The schema on which the table will reside. 1938 1939 Returns 1940 ------- 1941 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 
1942 """ 1943 return get_create_table_queries( 1944 query_or_dtypes, 1945 new_table, 1946 flavor, 1947 schema=schema, 1948 primary_key=None, 1949 )[0] 1950 1951 1952def get_create_table_queries( 1953 query_or_dtypes: Union[str, Dict[str, str]], 1954 new_table: str, 1955 flavor: str, 1956 schema: Optional[str] = None, 1957 primary_key: Optional[str] = None, 1958 primary_key_db_type: Optional[str] = None, 1959 autoincrement: bool = False, 1960 datetime_column: Optional[str] = None, 1961) -> List[str]: 1962 """ 1963 Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary. 1964 1965 Parameters 1966 ---------- 1967 query_or_dtypes: Union[str, Dict[str, str]] 1968 The select query to use for the creation of the table. 1969 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 1970 1971 new_table: str 1972 The unquoted name of the new table. 1973 1974 flavor: str 1975 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 1976 1977 schema: Optional[str], default None 1978 The schema on which the table will reside. 1979 1980 primary_key: Optional[str], default None 1981 If provided, designate this column as the primary key in the new table. 1982 1983 primary_key_db_type: Optional[str], default None 1984 If provided, alter the primary key to this type (to set NOT NULL constraint). 1985 1986 autoincrement: bool, default False 1987 If `True` and `primary_key` is provided, create the `primary_key` column 1988 as an auto-incrementing integer column. 1989 1990 datetime_column: Optional[str], default None 1991 If provided, include this column in the primary key. 1992 Applicable to TimescaleDB only. 1993 1994 Returns 1995 ------- 1996 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 
1997 """ 1998 if not isinstance(query_or_dtypes, (str, dict)): 1999 raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.") 2000 2001 method = ( 2002 _get_create_table_query_from_cte 2003 if isinstance(query_or_dtypes, str) 2004 else _get_create_table_query_from_dtypes 2005 ) 2006 return method( 2007 query_or_dtypes, 2008 new_table, 2009 flavor, 2010 schema=schema, 2011 primary_key=primary_key, 2012 primary_key_db_type=primary_key_db_type, 2013 autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS), 2014 datetime_column=datetime_column, 2015 ) 2016 2017 2018def _get_create_table_query_from_dtypes( 2019 dtypes: Dict[str, str], 2020 new_table: str, 2021 flavor: str, 2022 schema: Optional[str] = None, 2023 primary_key: Optional[str] = None, 2024 primary_key_db_type: Optional[str] = None, 2025 autoincrement: bool = False, 2026 datetime_column: Optional[str] = None, 2027) -> List[str]: 2028 """ 2029 Create a new table from a `dtypes` dictionary. 2030 """ 2031 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, AUTO_INCREMENT_COLUMN_FLAVORS 2032 if not dtypes and not primary_key: 2033 raise ValueError(f"Expecting columns for table '{new_table}'.") 2034 2035 if flavor in SKIP_AUTO_INCREMENT_FLAVORS: 2036 autoincrement = False 2037 2038 cols_types = ( 2039 [ 2040 ( 2041 primary_key, 2042 get_db_type_from_pd_type(dtypes.get(primary_key, 'int') or 'int', flavor=flavor) 2043 ) 2044 ] 2045 if primary_key 2046 else [] 2047 ) + [ 2048 (col, get_db_type_from_pd_type(typ, flavor=flavor)) 2049 for col, typ in dtypes.items() 2050 if col != primary_key 2051 ] 2052 2053 table_name = sql_item_name(new_table, schema=schema, flavor=flavor) 2054 primary_key_name = sql_item_name(primary_key, flavor) if primary_key else None 2055 primary_key_constraint_name = ( 2056 sql_item_name(f'PK_{new_table}', flavor, None) 2057 if primary_key 2058 else None 2059 ) 2060 datetime_column_name = sql_item_name(datetime_column, flavor) if 
datetime_column else None 2061 primary_key_clustered = ( 2062 "CLUSTERED" 2063 if not datetime_column or datetime_column == primary_key 2064 else "NONCLUSTERED" 2065 ) 2066 query = f"CREATE TABLE {table_name} (" 2067 if primary_key: 2068 col_db_type = cols_types[0][1] 2069 auto_increment_str = (' ' + AUTO_INCREMENT_COLUMN_FLAVORS.get( 2070 flavor, 2071 AUTO_INCREMENT_COLUMN_FLAVORS['default'] 2072 )) if autoincrement or primary_key not in dtypes else '' 2073 col_name = sql_item_name(primary_key, flavor=flavor, schema=None) 2074 2075 if flavor == 'sqlite': 2076 query += ( 2077 f"\n {col_name} " 2078 + (f"{col_db_type}" if not auto_increment_str else 'INTEGER') 2079 + f" PRIMARY KEY{auto_increment_str} NOT NULL," 2080 ) 2081 elif flavor == 'oracle': 2082 query += f"\n {col_name} {col_db_type} {auto_increment_str} PRIMARY KEY," 2083 elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key: 2084 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2085 elif flavor == 'mssql': 2086 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2087 else: 2088 query += f"\n {col_name} {col_db_type} PRIMARY KEY{auto_increment_str} NOT NULL," 2089 2090 for col, db_type in cols_types: 2091 if col == primary_key: 2092 continue 2093 col_name = sql_item_name(col, schema=None, flavor=flavor) 2094 query += f"\n {col_name} {db_type}," 2095 if ( 2096 flavor == 'timescaledb' 2097 and datetime_column 2098 and primary_key 2099 and datetime_column != primary_key 2100 ): 2101 query += f"\n PRIMARY KEY({datetime_column_name}, {primary_key_name})," 2102 2103 if flavor == 'mssql' and primary_key: 2104 query += f"\n CONSTRAINT {primary_key_constraint_name} PRIMARY KEY {primary_key_clustered} ({primary_key_name})," 2105 2106 query = query[:-1] 2107 query += "\n)" 2108 2109 queries = [query] 2110 return queries 2111 2112 2113def _get_create_table_query_from_cte( 2114 query: str, 2115 new_table: str, 2116 flavor: str, 2117 schema: 
def _get_create_table_query_from_cte(
    query: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Create a new table from a CTE query.

    Parameters
    ----------
    query: str
        The `SELECT` query (which may itself contain CTEs) used to populate the table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query.

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, add this column as the primary key after creation.

    primary_key_db_type: Optional[str], default None
        If provided, alter the primary key column to this type first (MSSQL branch only).

    autoincrement: bool, default False
        Accepted for signature parity with `_get_create_table_query_from_dtypes()`;
        not used in this function.

    datetime_column: Optional[str], default None
        If provided (TimescaleDB), include this column in the composite primary key.

    Returns
    -------
    A list of queries: the table-creation query, followed by primary-key
    `ALTER TABLE` queries when `primary_key` is given.
    """
    import textwrap
    create_cte = 'create_query'
    create_cte_name = sql_item_name(create_cte, flavor, None)
    new_table_name = sql_item_name(new_table, flavor, schema)
    primary_key_constraint_name = (
        sql_item_name(f'PK_{new_table}', flavor, None)
        if primary_key
        else None
    )
    primary_key_name = (
        sql_item_name(primary_key, flavor, None)
        if primary_key
        else None
    )
    ### MSSQL: cluster on the PK unless a separate datetime column exists.
    primary_key_clustered = (
        "CLUSTERED"
        if not datetime_column or datetime_column == primary_key
        else "NONCLUSTERED"
    )
    datetime_column_name = (
        sql_item_name(datetime_column, flavor)
        if datetime_column
        else None
    )
    if flavor in ('mssql',):
        query = query.lstrip()
        ### If the query already starts with a WITH clause, splice our CTE into it;
        ### otherwise wrap the whole query as a derived table.
        if query.lower().startswith('with '):
            final_select_ix = query.lower().rfind('select')
            create_table_queries = [
                (
                    query[:final_select_ix].rstrip() + ',\n'
                    + f"{create_cte_name} AS (\n"
                    + textwrap.indent(query[final_select_ix:], ' ')
                    + "\n)\n"
                    + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}"
                ),
            ]
        else:
            create_table_queries = [
                (
                    "SELECT *\n"
                    f"INTO {new_table_name}\n"
                    f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}"
                ),
            ]

        alter_type_queries = []
        ### Optionally retype the PK column first so the NOT NULL constraint holds.
        if primary_key_db_type:
            alter_type_queries.extend([
                (
                    f"ALTER TABLE {new_table_name}\n"
                    f"ALTER COLUMN {primary_key_name} {primary_key_db_type} NOT NULL"
                ),
            ])
        alter_type_queries.extend([
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD CONSTRAINT {primary_key_constraint_name} "
                f"PRIMARY KEY {primary_key_clustered} ({primary_key_name})"
            ),
        ])
    elif flavor in (None,):
        ### NOTE(review): this branch is only reachable when `flavor` is literally
        ### None — presumably legacy; confirm before removing.
        create_table_queries = [
            (
                ### FIX: was `textwrap.index(query, ' ')` — `textwrap` has no
                ### `index` attribute; `textwrap.indent` is the intended call.
                f"WITH {create_cte_name} AS (\n{textwrap.indent(query, ' ')}\n)\n"
                f"CREATE TABLE {new_table_name} AS\n"
                "SELECT *\n"
                f"FROM {create_cte_name}"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]
    elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'):
        create_table_queries = [
            (
                f"CREATE TABLE {new_table_name} AS\n"
                "SELECT *\n"
                f"FROM (\n{textwrap.indent(query, ' ')}\n)"
                + (f" AS {create_cte_name}" if flavor != 'oracle' else '')
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                ### FIX: missing f-prefix emitted the literal text
                ### `{primary_key_name}` into the generated SQL.
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]
    elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key:
        create_table_queries = [
            (
                "SELECT *\n"
                f"INTO {new_table_name}\n"
                f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}\n"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({datetime_column_name}, {primary_key_name})"
            ),
        ]
    else:
        create_table_queries = [
            (
                "SELECT *\n"
                f"INTO {new_table_name}\n"
                f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]

    if not primary_key:
        return create_table_queries

    return create_table_queries + alter_type_queries
2264 2265 Parameters 2266 ---------- 2267 sub_query: str 2268 The query to be referenced. This may itself contain CTEs. 2269 Unless `cte_name` is provided, this will be aliased as `src`. 2270 2271 parent_query: str 2272 The larger query to append which references the subquery. 2273 This must not contain CTEs. 2274 2275 flavor: str 2276 The database flavor, e.g. `'mssql'`. 2277 2278 cte_name: str, default 'src' 2279 The CTE alias, defaults to `src`. 2280 2281 Returns 2282 ------- 2283 An encapsulating query which allows you to treat `sub_query` as a temporary table. 2284 2285 Examples 2286 -------- 2287 2288 ```python 2289 from meerschaum.utils.sql import wrap_query_with_cte 2290 sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo" 2291 parent_query = "SELECT newval * 3 FROM src" 2292 query = wrap_query_with_cte(sub_query, parent_query, 'mssql') 2293 print(query) 2294 # WITH foo AS (SELECT 1 AS val), 2295 # [src] AS ( 2296 # SELECT (val * 2) AS newval FROM foo 2297 # ) 2298 # SELECT newval * 3 FROM src 2299 ``` 2300 2301 """ 2302 import textwrap 2303 sub_query = sub_query.lstrip() 2304 cte_name_quoted = sql_item_name(cte_name, flavor, None) 2305 2306 if flavor in NO_CTE_FLAVORS: 2307 return ( 2308 parent_query 2309 .replace(cte_name_quoted, '--MRSM_SUBQUERY--') 2310 .replace(cte_name, '--MRSM_SUBQUERY--') 2311 .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}") 2312 ) 2313 2314 if sub_query.lstrip().lower().startswith('with '): 2315 final_select_ix = sub_query.lower().rfind('select') 2316 return ( 2317 sub_query[:final_select_ix].rstrip() + ',\n' 2318 + f"{cte_name_quoted} AS (\n" 2319 + ' ' + sub_query[final_select_ix:] 2320 + "\n)\n" 2321 + parent_query 2322 ) 2323 2324 return ( 2325 f"WITH {cte_name_quoted} AS (\n" 2326 f"{textwrap.indent(sub_query, ' ')}\n" 2327 f")\n{parent_query}" 2328 ) 2329 2330 2331def format_cte_subquery( 2332 sub_query: str, 2333 flavor: str, 2334 sub_name: str = 'src', 2335 cols_to_select: 
Union[List[str], str] = '*', 2336) -> str: 2337 """ 2338 Given a subquery, build a wrapper query that selects from the CTE subquery. 2339 2340 Parameters 2341 ---------- 2342 sub_query: str 2343 The subquery to wrap. 2344 2345 flavor: str 2346 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2347 2348 sub_name: str, default 'src' 2349 If possible, give this name to the CTE (must be unquoted). 2350 2351 cols_to_select: Union[List[str], str], default '' 2352 If specified, choose which columns to select from the CTE. 2353 If a list of strings is provided, each item will be quoted and joined with commas. 2354 If a string is given, assume it is quoted and insert it into the query. 2355 2356 Returns 2357 ------- 2358 A wrapper query that selects from the CTE. 2359 """ 2360 quoted_sub_name = sql_item_name(sub_name, flavor, None) 2361 cols_str = ( 2362 cols_to_select 2363 if isinstance(cols_to_select, str) 2364 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 2365 ) 2366 parent_query = ( 2367 f"SELECT {cols_str}\n" 2368 f"FROM {quoted_sub_name}" 2369 ) 2370 return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name) 2371 2372 2373def session_execute( 2374 session: 'sqlalchemy.orm.session.Session', 2375 queries: Union[List[str], str], 2376 with_results: bool = False, 2377 debug: bool = False, 2378) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 2379 """ 2380 Similar to `SQLConnector.exec_queries()`, execute a list of queries 2381 and roll back when one fails. 2382 2383 Parameters 2384 ---------- 2385 session: sqlalchemy.orm.session.Session 2386 A SQLAlchemy session representing a transaction. 2387 2388 queries: Union[List[str], str] 2389 A query or list of queries to be executed. 2390 If a query fails, roll back the session. 2391 2392 with_results: bool, default False 2393 If `True`, return a list of result objects. 
2394 2395 Returns 2396 ------- 2397 A `SuccessTuple` indicating the queries were successfully executed. 2398 If `with_results`, return the `SuccessTuple` and a list of results. 2399 """ 2400 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2401 if not isinstance(queries, list): 2402 queries = [queries] 2403 successes, msgs, results = [], [], [] 2404 for query in queries: 2405 if debug: 2406 dprint(query) 2407 query_text = sqlalchemy.text(query) 2408 fail_msg = "Failed to execute queries." 2409 try: 2410 result = session.execute(query_text) 2411 query_success = result is not None 2412 query_msg = "Success" if query_success else fail_msg 2413 except Exception as e: 2414 query_success = False 2415 query_msg = f"{fail_msg}\n{e}" 2416 result = None 2417 successes.append(query_success) 2418 msgs.append(query_msg) 2419 results.append(result) 2420 if not query_success: 2421 if debug: 2422 dprint("Rolling back session.") 2423 session.rollback() 2424 break 2425 success, msg = all(successes), '\n'.join(msgs) 2426 if with_results: 2427 return (success, msg), results 2428 return success, msg 2429 2430 2431def get_reset_autoincrement_queries( 2432 table: str, 2433 column: str, 2434 connector: mrsm.connectors.SQLConnector, 2435 schema: Optional[str] = None, 2436 debug: bool = False, 2437) -> List[str]: 2438 """ 2439 Return a list of queries to reset a table's auto-increment counter to the next largest value. 2440 2441 Parameters 2442 ---------- 2443 table: str 2444 The name of the table on which the auto-incrementing column exists. 2445 2446 column: str 2447 The name of the auto-incrementing column. 2448 2449 connector: mrsm.connectors.SQLConnector 2450 The SQLConnector to the database on which the table exists. 2451 2452 schema: Optional[str], default None 2453 The schema of the table. Defaults to `connector.schema`. 2454 2455 Returns 2456 ------- 2457 A list of queries to be executed to reset the auto-incrementing column. 
2458 """ 2459 if not table_exists(table, connector, schema=schema, debug=debug): 2460 return [] 2461 2462 schema = schema or connector.schema 2463 max_id_name = sql_item_name('max_id', connector.flavor) 2464 table_name = sql_item_name(table, connector.flavor, schema) 2465 table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema) 2466 column_name = sql_item_name(column, connector.flavor) 2467 max_id = connector.value( 2468 f""" 2469 SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name} 2470 FROM {table_name} 2471 """, 2472 debug=debug, 2473 ) 2474 if max_id is None: 2475 return [] 2476 2477 reset_queries = reset_autoincrement_queries.get( 2478 connector.flavor, 2479 reset_autoincrement_queries['default'] 2480 ) 2481 if not isinstance(reset_queries, list): 2482 reset_queries = [reset_queries] 2483 2484 return [ 2485 query.format( 2486 column=column, 2487 column_name=column_name, 2488 table=table, 2489 table_name=table_name, 2490 table_seq_name=table_seq_name, 2491 val=max_id, 2492 val_plus_1=(max_id + 1), 2493 ) 2494 for query in reset_queries 2495 ]
def clean(substring: str) -> str:
    """
    Ensure a substring is clean enough to be inserted into a SQL query.
    Raises an exception when banned words are used.

    Parameters
    ----------
    substring: str
        The string to validate before injecting it into a query.

    Returns
    -------
    `None` when the string is clean; otherwise `error()` raises an exception.
    """
    banned_symbols = (';', '--', 'drop ')
    ### Lowercase once instead of once per banned symbol.
    lowered = str(substring).lower()
    for symbol in banned_symbols:
        if symbol in lowered:
            ### Defer the project import to the failure path to keep the hot path cheap.
            from meerschaum.utils.warnings import error
            error(f"Invalid string: '{substring}'")
Ensure a substring is clean enough to be inserted into a SQL query. Raises an exception when banned words are used.
def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now',
    db_type: Optional[str] = None,
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime], default `'now'`
        Base datetime to which to add dateparts.

    db_type: Optional[str], default None
        If provided, cast the datetime string as the type.
        Otherwise, infer this from the input datetime value.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor='mssql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
    >>> dateadd_str(
    ...     flavor='postgresql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type
    dateutil_parser = attempt_import('dateutil.parser')
    ### An integer `begin` (e.g. an epoch or serial index) short-circuits to plain
    ### integer arithmetic; the string-type check also catches numpy integer types.
    if 'int' in str(type(begin)).lower():
        num_str = str(begin)
        if number is not None and number != 0:
            num_str += (
                f' + {number}'
                if number > 0
                else f" - {number * -1}"
            )
        return num_str
    if not begin:
        return ''

    ### NOTE(review): `_original_begin` is never read below — looks vestigial.
    _original_begin = begin
    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        ### Timezone-aware datetimes are normalized to UTC before rendering.
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    ### Heuristic UTC detection: tzinfo when we have a datetime, otherwise look for
    ### a '+' anywhere or a '-' after the first ':' (i.e. an offset in the time part).
    dt_is_utc = (
        begin_time.tzinfo is not None
        if begin_time is not None
        else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1])
    )
    if db_type:
        db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower()
        dt_is_utc = dt_is_utc or db_type_is_utc
    db_type = db_type or get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    ### PostgreSQL-family: CAST + INTERVAL arithmetic.
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### NOTE(review): trims the quoted literal's trailing digits — presumably to
        ### reduce microsecond precision for DATETIME2 parsing; confirm the intended width.
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = (
            f"CAST({begin} AS DATETIME(6))"
            if begin != 'now'
            else 'UTC_TIMESTAMP(6)'
        )
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        if begin == 'now':
            ### NOTE(review): this strftime format (`%Y:%m:%d %M:%S.%f`) is missing `%H`
            ### and uses ':' between date parts, so it does not match `dt_format` below;
            ### and because `begin_time` is None for 'now', the literal is emitted
            ### unquoted and without TO_TIMESTAMP. Looks broken for `begin='now'` — confirm.
            begin = str(
                datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y:%m:%d %M:%S.%f')
            )
        elif begin_time:
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
Generate a DATEADD
clause depending on database flavor.
Parameters
flavor (str, default
'postgresql'
): SQL database flavor, e.g. 'postgresql', 'sqlite'.
Currently supported flavors:
'postgresql'
'timescaledb'
'citus'
'cockroachdb'
'duckdb'
'mssql'
'mysql'
'mariadb'
'sqlite'
'oracle'
datepart (str, default
'day'
): Which part of the date to modify. Supported values: 'year'
'month'
'day'
'hour'
'minute'
'second'
- number (Union[int, float], default
0
): How many units to add to the date part. - begin (Union[str, datetime], default
'now'
): Base datetime to which to add dateparts. - db_type (Optional[str], default None): If provided, cast the datetime string as the type. Otherwise, infer this from the input datetime value.
Returns
- The appropriate
DATEADD
string for the corresponding database flavor.
Examples
>>> dateadd_str(
... flavor='mssql',
... begin=datetime(2022, 1, 1, 0, 0),
... number=1,
... )
"DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
>>> dateadd_str(
... flavor='postgresql',
... begin=datetime(2022, 1, 1, 0, 0),
... number=1,
... )
"CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"
def test_connection(
    self,
    **kw: Any
) -> Union[bool, None]:
    """
    Test if a successful connection to the database may be made.

    Parameters
    ----------
    **kw:
        The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.

    Returns
    -------
    `True` if a connection is made, otherwise `False` or `None` in case of failure.

    """
    import warnings
    from meerschaum.connectors.poll import retry_connect

    ### Try exactly once with no wait; caller kwargs may override these defaults.
    retry_kwargs = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self}
    retry_kwargs.update(kw)

    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', 'Could not')
        try:
            return retry_connect(**retry_kwargs)
        except Exception:
            return False
Test if a successful connection to the database may be made.
Parameters
- **kw:: The keyword arguments are passed to
meerschaum.connectors.poll.retry_connect
.
Returns
True
if a connection is made, otherwise False
orNone
in case of failure.
def get_distinct_col_count(
    col: str,
    query: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    debug: bool = False
) -> Optional[int]:
    """
    Return the number of distinct values in a column of a SQL query.

    Parameters
    ----------
    col: str
        The column in the query to count.

    query: str
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None
        The SQLConnector to execute the query.
        Defaults to the default `'sql'` connector.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of distinct values in the column,
    or `None` if the query fails.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    _col_name = sql_item_name(col, connector.flavor, None)

    ### NOTE(review): mysql/mariadb take the subquery form — presumably due to
    ### limited CTE support in older versions; confirm before consolidating.
    _meta_query = (
        f"""
        WITH src AS ( {query} ),
        dist AS ( SELECT DISTINCT {_col_name} FROM src )
        SELECT COUNT(*) FROM dist"""
    ) if connector.flavor not in ('mysql', 'mariadb') else (
        f"""
        SELECT COUNT(*)
        FROM (
            SELECT DISTINCT {_col_name}
            FROM ({query}) AS src
        ) AS dist"""
    )

    result = connector.value(_meta_query, debug=debug)
    try:
        return int(result)
    except Exception:
        return None
Returns the number of distinct items in a column of a SQL query.
Parameters
- col (str:): The column in the query to count.
- query (str:): The SQL query to count from.
- connector (Optional[mrsm.connectors.sql.SQLConnector], default None:): The SQLConnector to execute the query.
- debug (bool, default False:): Verbosity toggle.
Returns
- An
int
of the number of distinct values in the query, or None
if the query fails.
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Parse SQL items depending on the flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.).

    schema: Optional[str], default None
        If provided, prefix the table name with the schema.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    truncated_item = truncate_item_name(str(item), flavor)
    if flavor == 'oracle':
        truncated_item = pg_capital(truncated_item, quote_capitals=True)
        ### NOTE: System-reserved words must be quoted.
        reserved_words = {
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table', 'row',
        }
        left_quote, right_quote = (
            ('"', '"')
            if truncated_item.lower() in reserved_words
            else ('', '')
        )
    else:
        left_quote, right_quote = table_wrappers.get(flavor, table_wrappers['default'])

    ### NOTE: SQLite does not support schemas,
    ### and MSSQL temp tables (names beginning with '#') take no schema prefix.
    if flavor == 'sqlite' or (flavor == 'mssql' and str(item).startswith('#')):
        schema = None

    schema_prefix = f"{left_quote}{schema}{right_quote}." if schema is not None else ''
    return f"{schema_prefix}{left_quote}{truncated_item}{right_quote}"
Parse SQL items depending on the flavor.
Parameters
- item (str): The database item (table, view, etc.) in need of quotes.
- flavor (str):
The database flavor (
'postgresql'
,'mssql'
,'sqllite'
, etc.). - schema (Optional[str], default None): If provided, prefix the table name with the schema.
Returns
- A
str
which contains the inputitem
wrapped in the corresponding escape characters.
Examples
>>> sql_item_name('table', 'sqlite')
'"table"'
>>> sql_item_name('table', 'mssql')
"[table]"
>>> sql_item_name('table', 'postgresql', schema='abc')
'"abc"."table"'
848def pg_capital(s: str, quote_capitals: bool = True) -> str: 849 """ 850 If string contains a capital letter, wrap it in double quotes. 851 852 Parameters 853 ---------- 854 s: str 855 The string to be escaped. 856 857 quote_capitals: bool, default True 858 If `False`, do not quote strings with contain only a mix of capital and lower-case letters. 859 860 Returns 861 ------- 862 The input string wrapped in quotes only if it needs them. 863 864 Examples 865 -------- 866 >>> pg_capital("My Table") 867 '"My Table"' 868 >>> pg_capital('my_table') 869 'my_table' 870 871 """ 872 if s.startswith('"') and s.endswith('"'): 873 return s 874 875 s = s.replace('"', '') 876 877 needs_quotes = s.startswith('_') 878 if not needs_quotes: 879 for c in s: 880 if c == '_': 881 continue 882 883 if not c.isalnum() or (quote_capitals and c.isupper()): 884 needs_quotes = True 885 break 886 887 if needs_quotes: 888 return '"' + s + '"' 889 890 return s
If string contains a capital letter, wrap it in double quotes.
Parameters
- s (str): The string to be escaped.
- quote_capitals (bool, default True):
If
False
, do not quote strings that contain only a mix of capital and lower-case letters.
Returns
- The input string wrapped in quotes only if it needs them.
Examples
>>> pg_capital("My Table")
'"My Table"'
>>> pg_capital('my_table')
'my_table'
def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.

    Currently a no-op: the input string is returned unchanged.
    """
    return s
Capitalize the string of an item on an Oracle database.
def truncate_item_name(item: str, flavor: str) -> str:
    """
    Truncate item names to stay within the database flavor's character limit.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections
    ### Look up the flavor's identifier length limit, falling back to the default.
    max_len = max_name_lens.get(flavor, max_name_lens['default'])
    return truncate_string_sections(item, max_len=max_len)
Truncate item names to stay within the database flavor's character limit.
Parameters
- item (str): The database item being referenced. This string is the "canonical" name internally.
- flavor (str):
The flavor of the database on which
item
resides.
Returns
- The truncated string.
def build_where(
    params: Dict[str, Any],
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    with_where: bool = True,
) -> str:
    """
    Build the `WHERE` clause based on the input criteria.

    Parameters
    ----------
    params: Dict[str, Any]
        The keywords dictionary to convert into a WHERE clause.
        If a value is a string which begins with an underscore, negate that value
        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
        A value of `_None` will be interpreted as `IS NOT NULL`.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The Meerschaum SQLConnector that will be executing the query.
        The connector is used to extract the SQL dialect.

    with_where: bool, default True
        If `True`, include the leading `'WHERE'` string.

    Returns
    -------
    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.

    Examples
    --------
    ```
    >>> print(build_where({'foo': [1, 2, 3]}))

    WHERE
        "foo" IN ('1', '2', '3')
    ```
    """
    import json
    from meerschaum.config.static import STATIC_CONFIG
    from meerschaum.utils.warnings import warn
    from meerschaum.utils.dtypes import value_is_null, none_if_null
    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
    ### Serialize the params to scan them for SQL injection attempts.
    try:
        params_json = json.dumps(params)
    except Exception:
        params_json = str(params)
    bad_words = ['drop ', '--', ';']
    for word in bad_words:
        if word in params_json.lower():
            warn("Aborting build_where() due to possible SQL injection.")
            return ''

    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')
    where = ""
    leading_and = "\n    AND "
    for key, value in params.items():
        _key = sql_item_name(key, connector.flavor, None)
        ### search across a list (i.e. IN syntax)
        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
            ### Split items into includes/excludes, then each into null/not-null,
            ### because NULL cannot appear inside an IN (...) list.
            includes = [
                none_if_null(item)
                for item in value
                if not str(item).startswith(negation_prefix)
            ]
            null_includes = [item for item in includes if item is None]
            not_null_includes = [item for item in includes if item is not None]
            excludes = [
                none_if_null(str(item)[len(negation_prefix):])
                for item in value
                if str(item).startswith(negation_prefix)
            ]
            null_excludes = [item for item in excludes if item is None]
            not_null_excludes = [item for item in excludes if item is not None]

            if includes:
                where += f"{leading_and}("
            if not_null_includes:
                where += f"{_key} IN ("
                for item in not_null_includes:
                    ### Escape single quotes by doubling them.
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_includes:
                where += ("\n    OR " if not_null_includes else "") + f"{_key} IS NULL"
            if includes:
                where += ")"

            if excludes:
                where += f"{leading_and}("
            if not_null_excludes:
                where += f"{_key} NOT IN ("
                for item in not_null_excludes:
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_excludes:
                where += ("\n    AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
            if excludes:
                where += ")"

            continue

        ### search a dictionary
        elif isinstance(value, dict):
            where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'")
            continue

        eq_sign = '='
        is_null = 'IS NULL'
        ### Normalize null-like values (per `value_is_null`) while preserving negation.
        if value_is_null(str(value).lstrip(negation_prefix)):
            value = (
                (negation_prefix + 'None')
                if str(value).startswith(negation_prefix)
                else None
            )
        if str(value).startswith(negation_prefix):
            value = str(value)[len(negation_prefix):]
            eq_sign = '!='
            if value_is_null(value):
                value = None
                is_null = 'IS NOT NULL'
        quoted_value = str(value).replace("'", "''")
        where += (
            f"{leading_and}{_key} "
            + (is_null if value is None else f"{eq_sign} '{quoted_value}'")
        )

    if len(where) > 1:
        ### Strip the first `leading_and` and optionally prepend the WHERE keyword.
        where = ("\nWHERE\n    " if with_where else '') + where[len(leading_and):]
    return where
Build the WHERE
clause based on the input criteria.
Parameters
- params (Dict[str, Any]:):
The keywords dictionary to convert into a WHERE clause.
If a value is a string which begins with an underscore, negate that value
(e.g.
!=
instead of=
orNOT IN
instead ofIN
). A value of_None
will be interpreted asIS NOT NULL
. - connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The Meerschaum SQLConnector that will be executing the query. The connector is used to extract the SQL dialect.
- with_where (bool, default True:):
If
True
, include the leading'WHERE'
string.
Returns
- A
str
of theWHERE
clause from the inputparams
dictionary for the connector's flavor.
Examples
>>> print(build_where({'foo': [1, 2, 3]}))
WHERE
"foo" IN ('1', '2', '3')
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """Check if a table exists.

    Parameters
    ----------
    table: str
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    target_schema = schema or connector.schema
    ### Truncate the name first so lookups match how the table was created.
    truncated_table_name = truncate_item_name(str(table), connector.flavor)
    inspector = sqlalchemy.inspect(connector.engine)
    return inspector.has_table(truncated_table_name, schema=target_schema)
Check if a table exists.
Parameters
- table (str:): The name of the table in question.
- connector (mrsm.connectors.sql.SQLConnector): The connector to the database which holds the table.
- schema (Optional[str], default None):
Optionally specify the table schema.
Defaults to
connector.schema
. - debug (bool, default False :): Verbosity toggle.
Returns
- A
bool
indicating whether or not the table exists on the database.
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table,
    or `None` for DuckDB or if the table does not exist.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    ### DuckDB is not supported by this reflection path.
    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn

    ### A refresh invalidates the cached metadata before the cache lookup below.
    if refresh:
        connector.metadata.clear()

    tables_cache = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy', lazy=False)
    truncated_table_name = truncate_item_name(str(table), connector.flavor)

    reflection_kwargs = {'autoload_with': connector.engine}
    if schema:
        reflection_kwargs['schema'] = schema

    if refresh or truncated_table_name not in tables_cache:
        try:
            tables_cache[truncated_table_name] = sqlalchemy.Table(
                truncated_table_name,
                connector.metadata,
                **reflection_kwargs
            )
        except sqlalchemy.exc.NoSuchTableError:
            warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.")
            return None
    return tables_cache[truncated_table_name]
Construct a SQLAlchemy table from its name.
Parameters
- table (str): The name of the table on the database. Does not need to be escaped.
- connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The connector to the database which holds the table.
- schema (Optional[str], default None):
Specify on which schema the table resides.
Defaults to the schema set in
connector
. - refresh (bool, default False):
If
True
, rebuild the cached table object. - debug (bool, default False:): Verbosity toggle.
Returns
- A
sqlalchemy.Table
object for the table.
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to data types.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    ### Some flavors do not support schemas; others fall back to a default schema.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Pre-compute truncated / lower / upper variants of the table name because
    ### flavors differ in how identifiers are case-folded in their system catalogs.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names beginning with '#') live in `tempdb`.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        textwrap.dedent(columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )).lstrip().rstrip()
    )

    ### Expected column order of the catalog query's result set.
    cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale']
    result_cols_ix = dict(enumerate(cols))

    ### Only SQLConnector.execute accepts `debug`; raw sessions/engines do not.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        ### DuckDB goes through `connectable.read()` (a DataFrame) instead of `execute()`.
        result_rows = (
            [
                row
                for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle catalogs report unquoted identifiers in upper-case;
        ### all-upper alphabetic names are folded back to lower-case here.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper() + (
                ### Append (precision,scale) only when both are present and non-zero.
                f'({precision},{scale})'
                if (
                    (precision := doc.get('numeric_precision', None))
                    and
                    (scale := doc.get('numeric_scale', None))
                )
                else ''
            )
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
Return a dictionary mapping a table's columns to data types. This is useful for inspecting tables created during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
an explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table()
instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor
. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str], default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to data types.
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
    Use this function if you are confident the table name is unique or if you have
    an explicit schema.
    To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    Returns
    -------
    A dictionary mapping column names to a list of index documents
    (dictionaries with keys `'name'` and `'type'`, plus `'clustered'` for MSSQL).
    """
    import textwrap
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    # Fall back to the connectable's own flavor when none is given explicitly.
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    # DuckDB results are fetched via `SQLConnector.read()` below.
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    # Flavors without schema support ignore `schema`; otherwise use the default.
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    # These flavors have no meaningful cross-database filter here.
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    # Pre-compute truncated / case-variant names the catalog queries may need.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    # MSSQL temp tables ('#...') live in tempdb.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    # Build the flavor-specific index catalog query.
    cols_indices_query = sqlalchemy.text(
        textwrap.dedent(columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )).lstrip().rstrip()
    )

    # Expected positional layout of each result row (MSSQL adds `clustered`).
    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    if flavor == 'mssql':
        cols.append('clustered')
    result_cols_ix = dict(enumerate(cols))

    # Only `SQLConnector.execute()` accepts a `debug` kwarg; print otherwise.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        # DuckDB rows come back via `read()` as records; normalize to tuples.
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        # Zip each positional row into a named document.
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        # Restrict to the requested schema/database when provided.
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]
        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        cols_indices = defaultdict(list)
        for doc in cols_types_docs_filtered:
            # Oracle stores unquoted identifiers uppercased; lowercase them back
            # when they look auto-uppercased (all-caps, alphabetic).
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            index_doc = {
                'name': doc.get('index', None),
                'type': doc.get('index_type', None)
            }
            if flavor == 'mssql':
                index_doc['clustered'] = doc.get('clustered', None)
            cols_indices[col].append(index_doc)

        return dict(cols_indices)
    except Exception as e:
        # Best-effort: warn and return an empty mapping rather than raise.
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
Return a dictionary mapping a table's columns to lists of indices. This is useful for inspecting tables created during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
an explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table()
instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor
. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str], default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to a list of indices.
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    identity_insert: bool = False,
    null_indices: bool = True,
    cast_columns: bool = True,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: List[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    identity_insert: bool, default False
        If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries.
        Only applies for MSSQL upserts.

    null_indices: bool, default True
        If `False`, do not coalesce index columns before joining.

    cast_columns: bool, default True
        If `False`, do not cast update columns to the target table types.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes import are_dtypes_equal
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type
    # A raw session carries no flavor, so it must be provided explicitly.
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")
    # SQLite < 3.33.0 lacks UPDATE ... FROM; use the delete/insert fallback.
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and connectable.db_version < '3.33.0'
    ):
        flavor = 'sqlite_delete_insert'
    # Upserts use a separate template set keyed by '<flavor>-upsert'.
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = UPDATE_QUERIES.get(
        flavor_key,
        UPDATE_QUERIES['default']
    )
    # Some flavors define a single template, others a list of queries.
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    # Inspect both tables so only shared columns participate in the patch.
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    # Partition the shared columns into join keys and value (SET) columns,
    # applying any flavor-specific CAST type overrides.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    # Nothing to join on, or nothing to update (and not upserting): no-op.
    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    # Optionally COALESCE join keys so NULLs match NULLs across the join.
    coalesce_join_cols_str = ', '.join(
        [
            (
                (
                    'COALESCE('
                    + sql_item_name(c_name, flavor)
                    + ', '
                    + get_null_replacement(c_type, flavor)
                    + ')'
                )
                if null_indices
                else sql_item_name(c_name, flavor)
            )
            for c_name, c_type in join_cols_types
        ]
    )

    # For MERGE-style templates: skip the UPDATE clause entirely when only
    # index columns are present.
    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        # Build the `SET col = CAST(p.col AS type)` clause; prefixes select
        # which table alias appears on each side.
        if not value_cols:
            return ''

        # Timezone-aware columns need an explicit AT TIME ZONE on tz-capable flavors.
        utc_value_cols = {
            c_name
            for c_name, c_type in value_cols
            if ('utc' in get_pd_type_from_db_type(c_type).lower())
        } if flavor not in TIMEZONE_NAIVE_FLAVORS else set()

        # Per-column (prefix, cast-suffix, closing) fragments; empty tuples
        # mean "no CAST" (sqlite, or Oracle bytes which can't be cast here).
        cast_func_cols = {
            c_name: (
                ('', '', '')
                if not cast_columns or (
                    flavor == 'oracle'
                    and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes')
                )
                else (
                    ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + (
                        " AT TIME ZONE 'UTC'"
                        if c_name in utc_value_cols
                        else ''
                    ))
                    if flavor != 'sqlite'
                    else ('', '', '')
                )
            )
            for c_name, c_type in value_cols
        }
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + cast_func_cols[c_name][0]
                + r_prefix + sql_item_name(c_name, flavor, None)
                + cast_func_cols[c_name][1]
                + cast_func_cols[c_name][2]
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        # Build the join predicate over the key columns, optionally COALESCEd
        # so NULL keys compare equal.
        return '\n    AND\n    '.join([
            (
                (
                    "COALESCE("
                    + l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                    + '\n    =\n    '
                    + "COALESCE("
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                )
                if null_indices
                else (
                    l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ' = '
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                )
            ) for c_name, c_type in join_cols_types
        ])

    # Empty string placeholders let templates drop optional statements.
    skip_query_val = ""
    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    # MSSQL bounds via a [date_bounds] CTE; other flavors aggregate inline.
    date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]'
    min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]'
    max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]'
    date_bounds_subquery = (
        f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table})
            AND
            f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})"""
        if datetime_col
        else "1 = 1"
    )
    with_temp_date_bounds = f"""WITH [date_bounds] AS (
        SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name}
        FROM {patch_table_name}
    )""" if datetime_col else ""
    identity_insert_on = (
        f"SET IDENTITY_INSERT {target_table_name} ON"
        if identity_insert
        else skip_query_val
    )
    identity_insert_off = (
        f"SET IDENTITY_INSERT {target_table_name} OFF"
        if identity_insert
        else skip_query_val
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "\n    WHEN MATCHED THEN\n"
        f"        UPDATE {sets_subquery('', 'p.')}"
    )

    # MySQL/MariaDB upsert fragments.
    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    # Render every template with the assembled fragments; unused keys are
    # simply ignored by templates which don't reference them.
    formatted_queries = [
        textwrap.dedent(base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
            with_temp_date_bounds=with_temp_date_bounds,
            identity_insert_on=identity_insert_on,
            identity_insert_off=identity_insert_off,
        )).lstrip().rstrip()
        for base_query in base_queries
    ]

    ### NOTE: Allow for skipping some queries.
    return [query for query in formatted_queries if query]
Build a list of MERGE
, UPDATE
, DELETE
/INSERT
queries to apply a patch to target table.
Parameters
- target (str): The name of the target table.
- patch (str): The name of the patch table. This should have the same shape as the target.
- connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]):
The
SQLConnector
or SQLAlchemy session which will later execute the queries. - join_cols (List[str]): The columns to use to join the patch to the target.
- flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
- upsert (bool, default False):
If
True
, return an upsert query rather than an update. - datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
- schema (Optional[str], default None):
If provided, use this schema when quoting the target table.
Defaults to
connector.schema
. - patch_schema (Optional[str], default None):
If provided, use this schema when quoting the patch table.
Defaults to
schema
. - identity_insert (bool, default False):
If
True
, includeSET IDENTITY_INSERT
queries before and after the update queries. Only applies for MSSQL upserts. - null_indices (bool, default True):
If
False
, do not coalesce index columns before joining. - cast_columns (bool, default True):
If
False
, do not cast update columns to the target table types. - debug (bool, default False): Verbosity toggle.
Returns
- A list of query strings to perform the update operation.
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The database type for which a NULL placeholder is needed.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A value which may stand in place of NULL for this type.
    `'None'` is returned if a value cannot be determined.
    """
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
    # Hoist the lowercased type: it is tested by nearly every branch below.
    typ_lower = typ.lower()
    if 'int' in typ_lower or typ_lower in ('numeric', 'number'):
        return '-987654321'
    if 'bool' in typ_lower or typ_lower == 'bit':
        # Resolve the flavor's boolean column type (with cast overrides)
        # so the sentinel can be cast explicitly.
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        # MySQL/MariaDB booleans are TINYINT, so a wide sentinel is safe;
        # other flavors use 0 to stay within the boolean domain.
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'
    if 'time' in typ_lower or 'date' in typ_lower:
        # Pass the exact DB type through only when it looks like a raw
        # (uppercased) database type name.
        db_type = typ if typ.isupper() else None
        return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type)
    if 'float' in typ_lower or 'double' in typ_lower or typ_lower in ('decimal',):
        return '-987654321.0'
    # Strip any length suffix, e.g. 'char(32)' -> 'char'.
    if flavor == 'oracle' and typ_lower.split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"
    if flavor == 'oracle' and typ_lower in ('blob', 'bytes'):
        return '00'
    if typ_lower in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"
    # Fallback: a string sentinel (NVARCHAR literal on Oracle).
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"
Return a value that may temporarily be used in place of NULL for this type.
Parameters
- typ (str): The typ to be converted to NULL.
- flavor (str): The database flavor for which this value will be used.
Returns
- A value which may stand in place of NULL for this type.
'None'
is returned if a value cannot be determined.
def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.
    """
    # Look up the flavor's version query (falling back to `SELECT VERSION()`)
    # and alias the result column with a properly quoted identifier.
    quoted_version_name = sql_item_name('version', conn.flavor, None)
    flavor_query = version_queries.get(conn.flavor, version_queries['default'])
    return conn.value(
        flavor_query.format(version_name=quoted_version_name),
        debug=debug,
    )
Fetch the database version if possible.
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    old_table_name = sql_item_name(old_table, flavor, schema)
    new_table_name = sql_item_name(new_table, flavor, None)
    tmp_table = '_tmp_rename_' + new_table
    tmp_table_name = sql_item_name(tmp_table, flavor, schema)

    # MSSQL renames via the system stored procedure (unquoted names).
    if flavor == 'mssql':
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""

    # DuckDB: copy through a temporary table, then drop both originals.
    if flavor == 'duckdb':
        rename_queries: List[str] = []
        rename_queries.extend(
            get_create_table_queries(
                f"SELECT * FROM {old_table_name}",
                tmp_table,
                'duckdb',
                schema,
            )
        )
        rename_queries.extend(
            get_create_table_queries(
                f"SELECT * FROM {tmp_table_name}",
                new_table,
                'duckdb',
                schema,
            )
        )
        rename_queries.append(f"DROP TABLE {if_exists_str} {tmp_table_name}")
        rename_queries.append(f"DROP TABLE {if_exists_str} {old_table_name}")
        return rename_queries

    # Everything else supports a standard ALTER TABLE ... RENAME TO.
    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]
Return queries to alter a table's name.
Parameters
- old_table (str): The unquoted name of the old table.
- new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table resides.
Returns
- A list of
ALTER TABLE
or equivalent queries for the database flavor.
def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    # Delegate to the list-returning variant and keep only the first query
    # for backwards compatibility.
    queries = get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )
    return queries[0]
NOTE: This function is deprecated. Use get_create_table_queries()
instead.
Return a query to create a new table from a SELECT
query.
Parameters
- query (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLE
query from the givendtypes
columns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table will reside.
Returns
- A
CREATE TABLE
(orSELECT INTO
) query for the database flavor.
def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    primary_key_db_type: Optional[str], default None
        If provided, alter the primary key to this type (to set NOT NULL constraint).

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    if not isinstance(query_or_dtypes, (str, dict)):
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    # A string is treated as a SELECT query (wrapped in a CTE);
    # a dictionary is treated as a dtypes mapping.
    if isinstance(query_or_dtypes, str):
        builder = _get_create_table_query_from_cte
    else:
        builder = _get_create_table_query_from_dtypes

    # Some flavors (e.g. Citus, DuckDB) do not support auto-increment columns.
    allow_autoincrement = autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS

    return builder(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        primary_key_db_type=primary_key_db_type,
        autoincrement=allow_autoincrement,
        datetime_column=datetime_column,
    )
Return a query to create a new table from a SELECT
query or a dtypes
dictionary.
Parameters
- query_or_dtypes (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLE
query from the givendtypes
columns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql'
,'postgresql'
). - schema (Optional[str], default None): The schema on which the table will reside.
- primary_key (Optional[str], default None): If provided, designate this column as the primary key in the new table.
- primary_key_db_type (Optional[str], default None): If provided, alter the primary key to this type (to set NOT NULL constraint).
- autoincrement (bool, default False):
If
True
andprimary_key
is provided, create theprimary_key
column as an auto-incrementing integer column. - datetime_column (Optional[str], default None): If provided, include this column in the primary key. Applicable to TimescaleDB only.
Returns
- A
CREATE TABLE
(orSELECT INTO
) query for the database flavor.
def wrap_query_with_cte(
    sub_query: str,
    parent_query: str,
    flavor: str,
    cte_name: str = "src",
) -> str:
    """
    Wrap a subquery in a CTE and append an encapsulating query.

    Parameters
    ----------
    sub_query: str
        The query to be referenced. This may itself contain CTEs.
        Unless `cte_name` is provided, this will be aliased as `src`.

    parent_query: str
        The larger query to append which references the subquery.
        This must not contain CTEs.

    flavor: str
        The database flavor, e.g. `'mssql'`.

    cte_name: str, default 'src'
        The CTE alias, defaults to `src`.

    Returns
    -------
    An encapsulating query which allows you to treat `sub_query` as a temporary table.

    Examples
    --------

    ```python
    from meerschaum.utils.sql import wrap_query_with_cte
    sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
    parent_query = "SELECT newval * 3 FROM src"
    query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
    print(query)
    # WITH foo AS (SELECT 1 AS val),
    # [src] AS (
    #     SELECT (val * 2) AS newval FROM foo
    # )
    # SELECT newval * 3 FROM src
    ```

    """
    import textwrap
    sub_query = sub_query.lstrip()
    cte_name_quoted = sql_item_name(cte_name, flavor, None)

    # Flavors without CTE support: inline the subquery as a derived table
    # wherever the parent query references the CTE (quoted or unquoted).
    # The placeholder pass avoids double-substituting overlapping names.
    if flavor in NO_CTE_FLAVORS:
        return (
            parent_query
            .replace(cte_name_quoted, '--MRSM_SUBQUERY--')
            .replace(cte_name, '--MRSM_SUBQUERY--')
            .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}")
        )

    # If the subquery already begins with its own WITH clause, splice this
    # CTE into that existing chain: everything before the final SELECT is
    # the existing CTE list, to which we append `cte_name AS (...)`.
    if sub_query.lstrip().lower().startswith('with '):
        final_select_ix = sub_query.lower().rfind('select')
        return (
            sub_query[:final_select_ix].rstrip() + ',\n'
            + f"{cte_name_quoted} AS (\n"
            + '    ' + sub_query[final_select_ix:]
            + "\n)\n"
            + parent_query
        )

    # Simple case: wrap the subquery in a fresh WITH clause.
    return (
        f"WITH {cte_name_quoted} AS (\n"
        f"{textwrap.indent(sub_query, '    ')}\n"
        f")\n{parent_query}"
    )
Wrap a subquery in a CTE and append an encapsulating query.

Parameters
- `sub_query` (str): The query to be referenced. This may itself contain CTEs. Unless `cte_name` is provided, this will be aliased as `src`.
- `parent_query` (str): The larger query to append which references the subquery. This must not contain CTEs.
- `flavor` (str): The database flavor, e.g. `'mssql'`.
- `cte_name` (str, default `'src'`): The CTE alias, defaults to `src`.

Returns
- An encapsulating query which allows you to treat `sub_query` as a temporary table.

Examples
```python
from meerschaum.utils.sql import wrap_query_with_cte
sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
parent_query = "SELECT newval * 3 FROM src"
query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
print(query)
# WITH foo AS (SELECT 1 AS val),
# [src] AS (
#     SELECT (val * 2) AS newval FROM foo
# )
# SELECT newval * 3 FROM src
```
def format_cte_subquery(
    sub_query: str,
    flavor: str,
    sub_name: str = 'src',
    cols_to_select: Union[List[str], str] = '*',
) -> str:
    """
    Given a subquery, build a wrapper query that selects from the CTE subquery.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        If possible, give this name to the CTE (must be unquoted).

    cols_to_select: Union[List[str], str], default '*'
        If specified, choose which columns to select from the CTE.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the CTE.
    """
    quoted_sub_name = sql_item_name(sub_name, flavor, None)
    ### A string is assumed pre-quoted; a list gets each column quoted per-flavor.
    cols_str = (
        cols_to_select
        if isinstance(cols_to_select, str)
        else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select])
    )
    parent_query = (
        f"SELECT {cols_str}\n"
        f"FROM {quoted_sub_name}"
    )
    return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
Given a subquery, build a wrapper query that selects from the CTE subquery.

Parameters
- `sub_query` (str): The subquery to wrap.
- `flavor` (str): The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).
- `sub_name` (str, default `'src'`): If possible, give this name to the CTE (must be unquoted).
- `cols_to_select` (Union[List[str], str], default `'*'`): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.

Returns
- A wrapper query that selects from the CTE.
def session_execute(
    session: 'sqlalchemy.orm.session.Session',
    queries: Union[List[str], str],
    with_results: bool = False,
    debug: bool = False,
) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session.

    with_results: bool, default False
        If `True`, return a list of result objects.

    Returns
    -------
    A `SuccessTuple` indicating the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    queries_list = queries if isinstance(queries, list) else [queries]
    fail_msg = "Failed to execute queries."
    successes = []
    msgs = []
    results = []
    for query_str in queries_list:
        if debug:
            dprint(query_str)
        ### `text()` stays outside the try-block so a malformed statement
        ### object raises to the caller rather than being folded into the msg.
        statement = sqlalchemy.text(query_str)
        try:
            query_result = session.execute(statement)
        except Exception as e:
            query_result = None
            query_ok = False
            query_msg = f"{fail_msg}\n{e}"
        else:
            query_ok = query_result is not None
            query_msg = "Success" if query_ok else fail_msg
        successes.append(query_ok)
        msgs.append(query_msg)
        results.append(query_result)
        if not query_ok:
            ### First failure aborts the whole batch and undoes prior queries.
            if debug:
                dprint("Rolling back session.")
            session.rollback()
            break
    overall_success = all(successes)
    overall_msg = '\n'.join(msgs)
    if with_results:
        return (overall_success, overall_msg), results
    return overall_success, overall_msg
Similar to `SQLConnector.exec_queries()`, execute a list of queries and roll back when one fails.

Parameters
- `session` (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
- `queries` (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
- `with_results` (bool, default `False`): If `True`, return a list of result objects.

Returns
- A `SuccessTuple` indicating the queries were successfully executed.
- If `with_results`, return the `SuccessTuple` and a list of results.
def get_reset_autoincrement_queries(
    table: str,
    column: str,
    connector: mrsm.connectors.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Return a list of queries to reset a table's auto-increment counter to the next largest value.

    Parameters
    ----------
    table: str
        The name of the table on which the auto-incrementing column exists.

    column: str
        The name of the auto-incrementing column.

    connector: mrsm.connectors.SQLConnector
        The SQLConnector to the database on which the table exists.

    schema: Optional[str], default None
        The schema of the table. Defaults to `connector.schema`.

    Returns
    -------
    A list of queries to be executed to reset the auto-incrementing column.
    """
    ### Nothing to reset if the table doesn't exist yet.
    if not table_exists(table, connector, schema=schema, debug=debug):
        return []

    schema = schema or connector.schema
    flavor = connector.flavor
    max_id_name = sql_item_name('max_id', flavor)
    table_name = sql_item_name(table, flavor, schema)
    ### PostgreSQL-style implicit sequence name: `{table}_{column}_seq`.
    table_seq_name = sql_item_name(f"{table}_{column}_seq", flavor, schema)
    column_name = sql_item_name(column, flavor)

    max_id = connector.value(
        f"""
        SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name}
        FROM {table_name}
        """,
        debug=debug,
    )
    if max_id is None:
        return []

    flavor_queries = reset_autoincrement_queries.get(
        flavor,
        reset_autoincrement_queries['default']
    )
    queries_to_format = (
        flavor_queries
        if isinstance(flavor_queries, list)
        else [flavor_queries]
    )

    format_kwargs = {
        'column': column,
        'column_name': column_name,
        'table': table,
        'table_name': table_name,
        'table_seq_name': table_seq_name,
        'val': max_id,
        'val_plus_1': (max_id + 1),
    }
    return [query.format(**format_kwargs) for query in queries_to_format]
Return a list of queries to reset a table's auto-increment counter to the next largest value.

Parameters
- `table` (str): The name of the table on which the auto-incrementing column exists.
- `column` (str): The name of the auto-incrementing column.
- `connector` (mrsm.connectors.SQLConnector): The SQLConnector to the database on which the table exists.
- `schema` (Optional[str], default `None`): The schema of the table. Defaults to `connector.schema`.

Returns
- A list of queries to be executed to reset the auto-incrementing column.