meerschaum.utils.sql
Flavor-specific SQL tools.
1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# vim:fenc=utf-8 4 5""" 6Flavor-specific SQL tools. 7""" 8 9from __future__ import annotations 10 11from datetime import datetime, timezone, timedelta 12import meerschaum as mrsm 13from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple 14### Preserve legacy imports. 15from meerschaum.utils.dtypes.sql import ( 16 DB_TO_PD_DTYPES, 17 PD_TO_DB_DTYPES_FLAVORS, 18 get_pd_type_from_db_type as get_pd_type, 19 get_db_type_from_pd_type as get_db_type, 20 TIMEZONE_NAIVE_FLAVORS, 21) 22from meerschaum.utils.warnings import warn 23from meerschaum.utils.debug import dprint 24 25test_queries = { 26 'default' : 'SELECT 1', 27 'oracle' : 'SELECT 1 FROM DUAL', 28 'informix' : 'SELECT COUNT(*) FROM systables', 29 'hsqldb' : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS', 30} 31### `table_name` is the escaped name of the table. 32### `table` is the unescaped name of the table. 33exists_queries = { 34 'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0", 35} 36version_queries = { 37 'default': "SELECT VERSION() AS {version_name}", 38 'sqlite': "SELECT SQLITE_VERSION() AS {version_name}", 39 'geopackage': "SELECT SQLITE_VERSION() AS {version_name}", 40 'mssql': "SELECT @@version", 41 'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1", 42} 43SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'} 44DROP_IF_EXISTS_FLAVORS = { 45 'timescaledb', 46 'timescaledb-ha', 47 'postgresql', 48 'postgis', 49 'citus', 50 'mssql', 51 'mysql', 52 'mariadb', 53 'sqlite', 54 'geopackage', 55} 56DROP_INDEX_IF_EXISTS_FLAVORS = { 57 'mssql', 58 'timescaledb', 59 'timescaledb-ha', 60 'postgresql', 61 'postgis', 62 'sqlite', 63 'geopackage', 64 'citus', 65} 66SKIP_AUTO_INCREMENT_FLAVORS = {'citus', 'duckdb'} 67COALESCE_UNIQUE_INDEX_FLAVORS = { 68 'timescaledb', 69 'timescaledb-ha', 70 'postgresql', 71 'postgis', 72 'citus', 73} 74UPDATE_QUERIES = { 75 'default': """ 76 UPDATE {target_table_name} AS f 77 {sets_subquery_none} 78 FROM {target_table_name} AS t 79 INNER JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 80 ON 81 {and_subquery_t} 82 WHERE 83 {and_subquery_f} 84 AND 85 {date_bounds_subquery} 86 """, 87 'timescaledb-upsert': """ 88 INSERT INTO {target_table_name} ({patch_cols_str}) 89 SELECT {patch_cols_str} 90 FROM {patch_table_name} 91 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 92 """, 93 'timescaledb-ha-upsert': """ 94 INSERT INTO {target_table_name} ({patch_cols_str}) 95 SELECT {patch_cols_str} 96 FROM {patch_table_name} 97 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 98 """, 99 'postgresql-upsert': """ 100 INSERT INTO {target_table_name} ({patch_cols_str}) 101 SELECT {patch_cols_str} 102 FROM {patch_table_name} 103 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 104 """, 105 'postgis-upsert': """ 106 INSERT INTO {target_table_name} ({patch_cols_str}) 107 SELECT {patch_cols_str} 108 FROM {patch_table_name} 109 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 110 """, 111 'citus-upsert': """ 112 INSERT INTO {target_table_name} ({patch_cols_str}) 113 SELECT {patch_cols_str} 114 FROM {patch_table_name} 115 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 116 """, 117 'cockroachdb-upsert': """ 118 INSERT INTO {target_table_name} ({patch_cols_str}) 119 SELECT {patch_cols_str} 120 FROM {patch_table_name} 121 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 122 """, 123 'mysql': """ 124 UPDATE {target_table_name} AS f 125 JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 126 ON 127 {and_subquery_f} 128 {sets_subquery_f} 129 WHERE 130 {date_bounds_subquery} 131 """, 132 'mysql-upsert': """ 133 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 134 SELECT {patch_cols_str} 135 FROM {patch_table_name} 136 {on_duplicate_key_update} 137 {cols_equal_values} 138 """, 139 'mariadb': """ 140 UPDATE {target_table_name} AS f 141 JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p 142 ON 143 {and_subquery_f} 144 {sets_subquery_f} 145 WHERE 146 {date_bounds_subquery} 147 """, 148 'mariadb-upsert': """ 149 INSERT {ignore}INTO {target_table_name} ({patch_cols_str}) 150 SELECT {patch_cols_str} 151 FROM {patch_table_name} 152 {on_duplicate_key_update} 153 {cols_equal_values} 154 """, 155 'mssql': """ 156 {with_temp_date_bounds} 157 MERGE {target_table_name} f 158 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 159 ON 160 {and_subquery_f} 161 AND 162 {date_bounds_subquery} 163 WHEN MATCHED THEN 164 UPDATE 165 {sets_subquery_none}; 166 """, 167 'mssql-upsert': [ 168 "{identity_insert_on}", 169 """ 170 {with_temp_date_bounds} 171 MERGE {target_table_name} f 172 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 173 ON 174 {and_subquery_f} 175 AND 176 {date_bounds_subquery}{when_matched_update_sets_subquery_none} 177 WHEN NOT MATCHED THEN 178 INSERT ({patch_cols_str}) 179 VALUES ({patch_cols_prefixed_str}); 180 """, 181 "{identity_insert_off}", 182 ], 183 'oracle': """ 184 MERGE INTO {target_table_name} f 185 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 186 ON ( 187 {and_subquery_f} 188 AND 189 {date_bounds_subquery} 190 ) 191 WHEN MATCHED THEN 192 UPDATE 193 {sets_subquery_none} 194 """, 195 'oracle-upsert': """ 196 MERGE INTO {target_table_name} f 197 USING (SELECT {patch_cols_str} FROM {patch_table_name}) p 198 ON ( 199 {and_subquery_f} 200 AND 201 {date_bounds_subquery} 202 ){when_matched_update_sets_subquery_none} 203 WHEN NOT MATCHED THEN 204 INSERT ({patch_cols_str}) 205 VALUES ({patch_cols_prefixed_str}) 206 """, 207 'sqlite-upsert': """ 208 INSERT INTO {target_table_name} ({patch_cols_str}) 209 SELECT {patch_cols_str} 210 FROM {patch_table_name} 211 WHERE true 212 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 213 """, 214 'sqlite_delete_insert': [ 215 """ 216 DELETE FROM {target_table_name} AS f 217 WHERE ROWID IN ( 218 SELECT t.ROWID 219 FROM {target_table_name} AS t 220 INNER JOIN (SELECT * FROM {patch_table_name}) AS p 221 ON {and_subquery_t} 222 ); 223 """, 224 """ 225 INSERT INTO {target_table_name} AS f 226 SELECT {patch_cols_str} FROM {patch_table_name} AS p 227 """, 228 ], 229 'geopackage-upsert': """ 230 INSERT INTO {target_table_name} ({patch_cols_str}) 231 SELECT {patch_cols_str} 232 FROM {patch_table_name} 233 WHERE true 234 ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded} 235 """, 236} 237columns_types_queries = { 238 'default': """ 239 SELECT 240 table_catalog AS database, 241 table_schema AS schema, 242 table_name AS table, 243 column_name AS column, 244 data_type AS type, 245 numeric_precision, 246 numeric_scale 247 FROM information_schema.columns 248 WHERE table_name IN ('{table}', '{table_trunc}') 249 """, 250 'sqlite': """ 251 SELECT 252 '' "database", 253 '' "schema", 254 m.name "table", 255 p.name "column", 256 p.type "type" 257 FROM sqlite_master m 258 LEFT OUTER JOIN pragma_table_info(m.name) p 259 ON m.name <> p.name 260 WHERE m.type = 'table' 261 AND m.name IN ('{table}', '{table_trunc}') 262 """, 263 'geopackage': """ 264 SELECT 265 '' "database", 266 '' "schema", 267 m.name "table", 268 p.name "column", 269 p.type "type" 270 FROM sqlite_master m 271 LEFT OUTER JOIN pragma_table_info(m.name) p 272 ON m.name <> p.name 273 WHERE m.type = 'table' 274 AND m.name IN ('{table}', '{table_trunc}') 275 """, 276 'mssql': """ 277 SELECT 278 TABLE_CATALOG AS [database], 279 TABLE_SCHEMA AS [schema], 280 TABLE_NAME AS [table], 281 COLUMN_NAME AS [column], 282 CASE WHEN CHARACTER_MAXIMUM_LENGTH = -1 THEN DATA_TYPE + '(max)' ELSE DATA_TYPE END AS [type], 283 NUMERIC_PRECISION AS [numeric_precision], 284 NUMERIC_SCALE AS [numeric_scale] 285 FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS WITH (NOLOCK) 286 WHERE TABLE_NAME IN ( 287 '{table}', 288 '{table_trunc}' 289 ) 290 291 """, 292 'mysql': """ 293 SELECT 294 TABLE_SCHEMA `database`, 295 TABLE_SCHEMA `schema`, 296 TABLE_NAME `table`, 297 COLUMN_NAME `column`, 298 DATA_TYPE `type`, 299 NUMERIC_PRECISION `numeric_precision`, 300 NUMERIC_SCALE `numeric_scale` 301 FROM INFORMATION_SCHEMA.COLUMNS 302 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 303 """, 304 'mariadb': """ 305 SELECT 306 TABLE_SCHEMA `database`, 307 TABLE_SCHEMA `schema`, 308 TABLE_NAME `table`, 309 COLUMN_NAME `column`, 310 DATA_TYPE `type`, 311 NUMERIC_PRECISION `numeric_precision`, 312 NUMERIC_SCALE `numeric_scale` 313 FROM INFORMATION_SCHEMA.COLUMNS 314 WHERE TABLE_NAME IN ('{table}', '{table_trunc}') 315 """, 316 'oracle': """ 317 SELECT 318 NULL AS "database", 319 NULL AS "schema", 320 TABLE_NAME AS "table", 321 COLUMN_NAME AS "column", 322 DATA_TYPE AS "type", 323 DATA_PRECISION AS "numeric_precision", 324 DATA_SCALE AS "numeric_scale" 325 FROM all_tab_columns 326 WHERE TABLE_NAME IN ( 327 '{table}', 328 '{table_trunc}', 329 '{table_lower}', 330 '{table_lower_trunc}', 331 '{table_upper}', 332 '{table_upper_trunc}' 333 ) 334 """, 335} 336hypertable_queries = { 337 'timescaledb': 'SELECT hypertable_size(\'{table_name}\')', 338 'timescaledb-ha': 'SELECT hypertable_size(\'{table_name}\')', 339 'citus': 'SELECT citus_table_size(\'{table_name}\')', 340} 341columns_indices_queries = { 342 'default': """ 343 SELECT 344 current_database() AS "database", 345 n.nspname AS "schema", 346 t.relname AS "table", 347 c.column_name AS "column", 348 i.relname AS "index", 349 CASE WHEN con.contype = 'p' THEN 'PRIMARY KEY' ELSE 'INDEX' END AS "index_type" 350 FROM pg_class t 351 INNER JOIN pg_index AS ix 352 ON t.oid = ix.indrelid 353 INNER JOIN pg_class AS i 354 ON i.oid = ix.indexrelid 355 INNER JOIN pg_namespace AS n 356 ON n.oid = t.relnamespace 357 INNER JOIN pg_attribute AS a 358 ON a.attnum = ANY(ix.indkey) 359 AND a.attrelid = t.oid 360 INNER JOIN information_schema.columns AS c 361 ON c.column_name = a.attname 362 AND c.table_name = t.relname 363 AND c.table_schema = n.nspname 364 LEFT JOIN pg_constraint AS con 365 ON con.conindid = i.oid 366 AND con.contype = 'p' 367 WHERE 368 t.relname IN ('{table}', '{table_trunc}') 369 AND n.nspname = '{schema}' 370 """, 371 'sqlite': """ 372 WITH indexed_columns AS ( 373 SELECT 374 '{table}' AS table_name, 375 pi.name AS column_name, 376 i.name AS index_name, 377 'INDEX' AS index_type 378 FROM 379 sqlite_master AS i, 380 pragma_index_info(i.name) AS pi 381 WHERE 382 i.type = 'index' 383 AND i.tbl_name = '{table}' 384 ), 385 primary_key_columns AS ( 386 SELECT 387 '{table}' AS table_name, 388 ti.name AS column_name, 389 'PRIMARY_KEY' AS index_name, 390 'PRIMARY KEY' AS index_type 391 FROM 392 pragma_table_info('{table}') AS ti 393 WHERE 394 ti.pk > 0 395 ) 396 SELECT 397 NULL AS "database", 398 NULL AS "schema", 399 "table_name" AS "table", 400 "column_name" AS "column", 401 "index_name" AS "index", 402 "index_type" 403 FROM indexed_columns 404 UNION ALL 405 SELECT 406 NULL AS "database", 407 NULL AS "schema", 408 table_name AS "table", 409 column_name AS "column", 410 index_name AS "index", 411 index_type 412 FROM primary_key_columns 413 """, 414 'geopackage': """ 415 WITH indexed_columns AS ( 416 SELECT 417 '{table}' AS table_name, 418 pi.name AS column_name, 419 i.name AS index_name, 420 'INDEX' AS index_type 421 FROM 422 sqlite_master AS i, 423 pragma_index_info(i.name) AS pi 424 WHERE 425 i.type = 'index' 426 AND i.tbl_name = '{table}' 427 ), 428 primary_key_columns AS ( 429 SELECT 430 '{table}' AS table_name, 431 ti.name AS column_name, 432 'PRIMARY_KEY' AS index_name, 433 'PRIMARY KEY' AS index_type 434 FROM 435 pragma_table_info('{table}') AS ti 436 WHERE 437 ti.pk > 0 438 ) 439 SELECT 440 NULL AS "database", 441 NULL AS "schema", 442 "table_name" AS "table", 443 "column_name" AS "column", 444 "index_name" AS "index", 445 "index_type" 446 FROM indexed_columns 447 UNION ALL 448 SELECT 449 NULL AS "database", 450 NULL AS "schema", 451 table_name AS "table", 452 column_name AS "column", 453 index_name AS "index", 454 index_type 455 FROM primary_key_columns 456 """, 457 'mssql': """ 458 SELECT 459 NULL AS [database], 460 s.name AS [schema], 461 t.name AS [table], 462 c.name AS [column], 463 i.name AS [index], 464 CASE 465 WHEN kc.type = 'PK' THEN 'PRIMARY KEY' 466 ELSE 'INDEX' 467 END AS [index_type], 468 CASE 469 WHEN i.type = 1 THEN CAST(1 AS BIT) 470 ELSE CAST(0 AS BIT) 471 END AS [clustered] 472 FROM 473 sys.schemas s WITH (NOLOCK) 474 INNER JOIN sys.tables t WITH (NOLOCK) 475 ON s.schema_id = t.schema_id 476 INNER JOIN sys.indexes i WITH (NOLOCK) 477 ON t.object_id = i.object_id 478 INNER JOIN sys.index_columns ic WITH (NOLOCK) 479 ON i.object_id = ic.object_id 480 AND i.index_id = ic.index_id 481 INNER JOIN sys.columns c WITH (NOLOCK) 482 ON ic.object_id = c.object_id 483 AND ic.column_id = c.column_id 484 LEFT JOIN sys.key_constraints kc WITH (NOLOCK) 485 ON kc.parent_object_id = i.object_id 486 AND kc.type = 'PK' 487 AND kc.name = i.name 488 WHERE 489 t.name IN ('{table}', '{table_trunc}') 490 AND s.name = '{schema}' 491 AND i.type IN (1, 2) 492 """, 493 'oracle': """ 494 SELECT 495 NULL AS "database", 496 ic.table_owner AS "schema", 497 ic.table_name AS "table", 498 ic.column_name AS "column", 499 i.index_name AS "index", 500 CASE 501 WHEN c.constraint_type = 'P' THEN 'PRIMARY KEY' 502 WHEN i.uniqueness = 'UNIQUE' THEN 'UNIQUE INDEX' 503 ELSE 'INDEX' 504 END AS index_type 505 FROM 506 all_ind_columns ic 507 INNER JOIN all_indexes i 508 ON ic.index_name = i.index_name 509 AND ic.table_owner = i.owner 510 LEFT JOIN all_constraints c 511 ON i.index_name = c.constraint_name 512 AND i.table_owner = c.owner 513 AND c.constraint_type = 'P' 514 WHERE ic.table_name IN ( 515 '{table}', 516 '{table_trunc}', 517 '{table_upper}', 518 '{table_upper_trunc}' 519 ) 520 """, 521 'mysql': """ 522 SELECT 523 TABLE_SCHEMA AS `database`, 524 TABLE_SCHEMA AS `schema`, 525 TABLE_NAME AS `table`, 526 COLUMN_NAME AS `column`, 527 INDEX_NAME AS `index`, 528 CASE 529 WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY' 530 ELSE 'INDEX' 531 END AS `index_type` 532 FROM 533 information_schema.STATISTICS 534 WHERE 535 TABLE_NAME IN ('{table}', '{table_trunc}') 536 """, 537 'mariadb': """ 538 SELECT 539 TABLE_SCHEMA AS `database`, 540 TABLE_SCHEMA AS `schema`, 541 TABLE_NAME AS `table`, 542 COLUMN_NAME AS `column`, 543 INDEX_NAME AS `index`, 544 CASE 545 WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY' 546 ELSE 'INDEX' 547 END AS `index_type` 548 FROM 549 information_schema.STATISTICS 550 WHERE 551 TABLE_NAME IN ('{table}', '{table_trunc}') 552 """, 553} 554reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = { 555 'default': """ 556 SELECT SETVAL(pg_get_serial_sequence('{table_name}', '{column}'), {val}) 557 FROM {table_name} 558 """, 559 'mssql': """ 560 DBCC CHECKIDENT ('{table_name}', RESEED, {val}) 561 """, 562 'mysql': """ 563 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 564 """, 565 'mariadb': """ 566 ALTER TABLE {table_name} AUTO_INCREMENT = {val} 567 """, 568 'sqlite': """ 569 UPDATE sqlite_sequence 570 SET seq = {val} 571 WHERE name = '{table}' 572 """, 573 'geopackage': """ 574 UPDATE sqlite_sequence 575 SET seq = {val} 576 WHERE name = '{table}' 577 """, 578 'oracle': ( 579 "ALTER TABLE {table_name} MODIFY {column_name} " 580 "GENERATED BY DEFAULT ON NULL AS IDENTITY (START WITH {val_plus_1})" 581 ), 582} 583table_wrappers = { 584 'default' : ('"', '"'), 585 'timescaledb' : ('"', '"'), 586 'timescaledb-ha': ('"', '"'), 587 'citus' : ('"', '"'), 588 'duckdb' : ('"', '"'), 589 'postgresql' : ('"', '"'), 590 'postgis' : ('"', '"'), 591 'sqlite' : ('"', '"'), 592 'geopackage' : ('"', '"'), 593 'mysql' : ('`', '`'), 594 'mariadb' : ('`', '`'), 595 'mssql' : ('[', ']'), 596 'cockroachdb' : ('"', '"'), 597 'oracle' : ('"', '"'), 598} 599max_name_lens = { 600 'default' : 64, 601 'mssql' : 128, 602 'oracle' : 30, 603 'postgresql' : 64, 604 'postgis' : 64, 605 'timescaledb' : 64, 606 'timescaledb-ha': 64, 607 'citus' : 64, 608 'cockroachdb' : 64, 609 'sqlite' : 1024, 610 'geopackage' : 1024, 611 'mysql' : 64, 612 'mariadb' : 64, 613} 614json_flavors = { 615 'postgresql', 616 'postgis', 617 'timescaledb', 618 'timescaledb-ha', 619 'citus', 620 'cockroachdb', 621} 622NO_SCHEMA_FLAVORS = { 623 'oracle', 624 'sqlite', 625 'geopackage', 626 'mysql', 627 'mariadb', 628 'duckdb', 629} 630DEFAULT_SCHEMA_FLAVORS = { 631 'postgresql': 'public', 632 'postgis': 'public', 633 'timescaledb': 'public', 634 'timescaledb-ha': 'public', 635 'citus': 'public', 636 'cockroachdb': 'public', 637 'mysql': 'mysql', 638 'mariadb': 'mysql', 639 'mssql': 'dbo', 640} 641OMIT_NULLSFIRST_FLAVORS = { 642 'mariadb', 643 'mysql', 644 'mssql', 645} 646 647SINGLE_ALTER_TABLE_FLAVORS = { 648 'duckdb', 649 'sqlite', 650 'geopackage', 651 'mssql', 652 'oracle', 653} 654NO_CTE_FLAVORS = { 655 'mysql', 656 'mariadb', 657} 658NO_SELECT_INTO_FLAVORS = { 659 'sqlite', 660 'geopackage', 661 'oracle', 662 'mysql', 663 'mariadb', 664 'duckdb', 665} 666 667 668def clean(substring: str) -> None: 669 """ 670 Ensure a substring is clean enough to be inserted into a SQL query. 671 Raises an exception when banned words are used. 672 """ 673 from meerschaum.utils.warnings import error 674 banned_symbols = [ 675 ';', '--', 'drop ', '/*', '*/', 676 'union ', 'exec ', 'execute ', 677 'insert ', 'update ', 'delete ', 678 'create ', 'alter ', 'truncate ', 679 ] 680 for symbol in banned_symbols: 681 if symbol in str(substring).lower(): 682 error(f"Invalid string: '{substring}'") 683 684 685_VALID_DATEPARTS = frozenset({'year', 'month', 'day', 'hour', 'minute', 'second'}) 686 687 688def dateadd_str( 689 flavor: str = 'postgresql', 690 datepart: str = 'day', 691 number: Union[int, float] = 0, 692 begin: Union[str, datetime, int] = 'now', 693 db_type: Optional[str] = None, 694) -> str: 695 """ 696 Generate a `DATEADD` clause depending on database flavor. 697 698 Parameters 699 ---------- 700 flavor: str, default `'postgresql'` 701 SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. 702 703 Currently supported flavors: 704 705 - `'postgresql'` 706 - `'postgis'` 707 - `'timescaledb'` 708 - `'timescaledb-ha'` 709 - `'citus'` 710 - `'cockroachdb'` 711 - `'duckdb'` 712 - `'mssql'` 713 - `'mysql'` 714 - `'mariadb'` 715 - `'sqlite'` 716 - `'geopackage'` 717 - `'oracle'` 718 719 datepart: str, default `'day'` 720 Which part of the date to modify. Supported values: 721 722 - `'year'` 723 - `'month'` 724 - `'day'` 725 - `'hour'` 726 - `'minute'` 727 - `'second'` 728 729 number: Union[int, float], default `0` 730 How many units to add to the date part. 731 732 begin: Union[str, datetime], default `'now'` 733 Base datetime to which to add dateparts. 734 735 db_type: Optional[str], default None 736 If provided, cast the datetime string as the type. 737 Otherwise, infer this from the input datetime value. 738 739 Returns 740 ------- 741 The appropriate `DATEADD` string for the corresponding database flavor. 742 743 Examples 744 -------- 745 >>> dateadd_str( 746 ... flavor='mssql', 747 ... begin=datetime(2022, 1, 1, 0, 0), 748 ... number=1, 749 ... ) 750 "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))" 751 >>> dateadd_str( 752 ... flavor='postgresql', 753 ... begin=datetime(2022, 1, 1, 0, 0), 754 ... number=1, 755 ... ) 756 "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'" 757 758 """ 759 from meerschaum.utils.packages import attempt_import 760 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type 761 dateutil_parser = attempt_import('dateutil.parser') 762 if 'int' in str(type(begin)).lower(): 763 num_str = str(begin) 764 if number is not None and number != 0: 765 num_str += ( 766 f' + {number}' 767 if number > 0 768 else f" - {number * -1}" 769 ) 770 return num_str 771 if datepart not in _VALID_DATEPARTS: 772 raise ValueError( 773 f"Invalid datepart '{datepart}'. Must be one of: {sorted(_VALID_DATEPARTS)}" 774 ) 775 if not begin: 776 return '' 777 778 _original_begin = begin 779 begin_time = None 780 ### Sanity check: make sure `begin` is a valid datetime before we inject anything. 781 if not isinstance(begin, datetime): 782 try: 783 begin_time = dateutil_parser.parse(begin) 784 except Exception: 785 begin_time = None 786 else: 787 begin_time = begin 788 789 ### Unable to parse into a datetime. 790 if begin_time is None: 791 ### Throw an error if banned symbols are included in the `begin` string. 792 clean(str(begin)) 793 ### If begin is a valid datetime, wrap it in quotes. 794 else: 795 if isinstance(begin, datetime) and begin.tzinfo is not None: 796 begin = begin.astimezone(timezone.utc) 797 begin = ( 798 f"'{begin.replace(tzinfo=None)}'" 799 if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS 800 else f"'{begin}'" 801 ) 802 803 dt_is_utc = ( 804 begin_time.tzinfo is not None 805 if begin_time is not None 806 else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1]) 807 ) 808 if db_type: 809 db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower() 810 dt_is_utc = dt_is_utc or db_type_is_utc 811 db_type = db_type or get_db_type_from_pd_type( 812 ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'), 813 flavor=flavor, 814 ) 815 816 da = "" 817 if flavor in ( 818 'postgresql', 819 'postgis', 820 'timescaledb', 821 'timescaledb-ha', 822 'cockroachdb', 823 'citus', 824 ): 825 begin = ( 826 f"CAST({begin} AS {db_type})" if begin != 'now' 827 else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})" 828 ) 829 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 830 831 elif flavor == 'duckdb': 832 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()' 833 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 834 835 elif flavor in ('mssql',): 836 if begin_time and begin_time.microsecond != 0 and not dt_is_utc: 837 begin = begin[:-4] + "'" 838 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()' 839 if dt_is_utc: 840 begin += " AT TIME ZONE 'UTC'" 841 da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin 842 843 elif flavor in ('mysql', 'mariadb'): 844 begin = ( 845 f"CAST({begin} AS DATETIME(6))" 846 if begin != 'now' 847 else 'UTC_TIMESTAMP(6)' 848 ) 849 da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin) 850 851 elif flavor in ('sqlite', 'geopackage'): 852 da = f"datetime({begin}, '{number} {datepart}')" 853 854 elif flavor == 'oracle': 855 if begin == 'now': 856 begin = str( 857 datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y:%m:%d %M:%S.%f') 858 ) 859 elif begin_time: 860 begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f')) 861 dt_format = 'YYYY-MM-DD HH24:MI:SS.FF' 862 _begin = f"'{begin}'" if begin_time else begin 863 da = ( 864 (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin) 865 + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "") 866 ) 867 return da 868 869 870def test_connection( 871 self, 872 **kw: Any 873) -> Union[bool, None]: 874 """ 875 Test if a successful connection to the database may be made. 876 877 Parameters 878 ---------- 879 **kw: 880 The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`. 881 882 Returns 883 ------- 884 `True` if a connection is made, otherwise `False` or `None` in case of failure. 885 886 """ 887 import warnings 888 from meerschaum.connectors.poll import retry_connect 889 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 890 _default_kw.update(kw) 891 with warnings.catch_warnings(): 892 warnings.filterwarnings('ignore', 'Could not') 893 try: 894 return retry_connect(**_default_kw) 895 except Exception: 896 return False 897 898 899def get_distinct_col_count( 900 col: str, 901 query: str, 902 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 903 debug: bool = False 904) -> Optional[int]: 905 """ 906 Returns the number of distinct items in a column of a SQL query. 907 908 Parameters 909 ---------- 910 col: str: 911 The column in the query to count. 912 913 query: str: 914 The SQL query to count from. 915 916 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 917 The SQLConnector to execute the query. 918 919 debug: bool, default False: 920 Verbosity toggle. 921 922 Returns 923 ------- 924 An `int` of the number of columns in the query or `None` if the query fails. 925 926 """ 927 if connector is None: 928 connector = mrsm.get_connector('sql') 929 930 _col_name = sql_item_name(col, connector.flavor, None) 931 932 _meta_query = ( 933 f""" 934 WITH src AS ( {query} ), 935 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 936 SELECT COUNT(*) FROM dist""" 937 ) if connector.flavor not in ('mysql', 'mariadb') else ( 938 f""" 939 SELECT COUNT(*) 940 FROM ( 941 SELECT DISTINCT {_col_name} 942 FROM ({query}) AS src 943 ) AS dist""" 944 ) 945 946 result = connector.value(_meta_query, debug=debug) 947 try: 948 return int(result) 949 except Exception: 950 return None 951 952 953def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 954 """ 955 Parse SQL items depending on the flavor. 956 957 Parameters 958 ---------- 959 item: str 960 The database item (table, view, etc.) in need of quotes. 961 962 flavor: str 963 The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.). 964 965 schema: Optional[str], default None 966 If provided, prefix the table name with the schema. 967 968 Returns 969 ------- 970 A `str` which contains the input `item` wrapped in the corresponding escape characters. 971 972 Examples 973 -------- 974 >>> sql_item_name('table', 'sqlite') 975 '"table"' 976 >>> sql_item_name('table', 'mssql') 977 "[table]" 978 >>> sql_item_name('table', 'postgresql', schema='abc') 979 '"abc"."table"' 980 981 """ 982 truncated_item = truncate_item_name(str(item), flavor) 983 if flavor == 'oracle': 984 truncated_item = pg_capital(truncated_item, quote_capitals=True) 985 ### NOTE: System-reserved words must be quoted. 986 if truncated_item.lower() in ( 987 'float', 'varchar', 'nvarchar', 'clob', 988 'boolean', 'integer', 'table', 'row', 'date', 'time', 989 ): 990 wrappers = ('"', '"') 991 else: 992 wrappers = ('', '') 993 else: 994 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 995 996 ### NOTE: SQLite does not support schemas. 997 if flavor in ('sqlite', 'geopackage'): 998 schema = None 999 elif flavor == 'mssql' and str(item).startswith('#'): 1000 schema = None 1001 1002 schema_prefix = ( 1003 (wrappers[0] + schema + wrappers[1] + '.') 1004 if schema is not None 1005 else '' 1006 ) 1007 1008 return schema_prefix + wrappers[0] + truncated_item + wrappers[1] 1009 1010 1011def pg_capital(s: str, quote_capitals: bool = True) -> str: 1012 """ 1013 If string contains a capital letter, wrap it in double quotes. 1014 1015 Parameters 1016 ---------- 1017 s: str 1018 The string to be escaped. 1019 1020 quote_capitals: bool, default True 1021 If `False`, do not quote strings with contain only a mix of capital and lower-case letters. 1022 1023 Returns 1024 ------- 1025 The input string wrapped in quotes only if it needs them. 1026 1027 Examples 1028 -------- 1029 >>> pg_capital("My Table") 1030 '"My Table"' 1031 >>> pg_capital('my_table') 1032 'my_table' 1033 1034 """ 1035 if s.startswith('"') and s.endswith('"'): 1036 return s 1037 1038 s = s.replace('"', '') 1039 1040 needs_quotes = s.startswith('_') 1041 if not needs_quotes: 1042 for c in s: 1043 if c == '_': 1044 continue 1045 1046 if not c.isalnum() or (quote_capitals and c.isupper()): 1047 needs_quotes = True 1048 break 1049 1050 if needs_quotes: 1051 return '"' + s + '"' 1052 1053 return s 1054 1055 1056def oracle_capital(s: str) -> str: 1057 """ 1058 Capitalize the string of an item on an Oracle database. 1059 """ 1060 return s 1061 1062 1063def truncate_item_name(item: str, flavor: str) -> str: 1064 """ 1065 Truncate item names to stay within the database flavor's character limit. 1066 1067 Parameters 1068 ---------- 1069 item: str 1070 The database item being referenced. This string is the "canonical" name internally. 1071 1072 flavor: str 1073 The flavor of the database on which `item` resides. 1074 1075 Returns 1076 ------- 1077 The truncated string. 1078 """ 1079 from meerschaum.utils.misc import truncate_string_sections 1080 return truncate_string_sections( 1081 item, max_len=max_name_lens.get(flavor, max_name_lens['default']) 1082 ) 1083 1084 1085def build_where( 1086 params: Dict[str, Any], 1087 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1088 with_where: bool = True, 1089 flavor: str = 'postgresql', 1090) -> str: 1091 """ 1092 Build the `WHERE` clause based on the input criteria. 1093 1094 Parameters 1095 ---------- 1096 params: Dict[str, Any]: 1097 The keywords dictionary to convert into a WHERE clause. 1098 If a value is a string which begins with an underscore, negate that value 1099 (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). 1100 A value of `_None` will be interpreted as `IS NOT NULL`. 1101 1102 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1103 The Meerschaum SQLConnector that will be executing the query. 1104 The connector is used to extract the SQL dialect. 1105 1106 with_where: bool, default True: 1107 If `True`, include the leading `'WHERE'` string. 1108 1109 flavor: str, default 'postgresql' 1110 If `connector` is `None`, fall back to this flavor. 1111 1112 Returns 1113 ------- 1114 A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor. 1115 1116 Examples 1117 -------- 1118 ``` 1119 >>> print(build_where({'foo': [1, 2, 3]})) 1120 1121 WHERE 1122 "foo" IN ('1', '2', '3') 1123 ``` 1124 """ 1125 import json 1126 from meerschaum._internal.static import STATIC_CONFIG 1127 from meerschaum.utils.warnings import warn 1128 from meerschaum.utils.dtypes import value_is_null, none_if_null 1129 negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix'] 1130 try: 1131 params_json = json.dumps(params) 1132 except Exception: 1133 params_json = str(params) 1134 bad_words = ['drop ', '--', ';'] 1135 for word in bad_words: 1136 if word in params_json.lower(): 1137 warn("Aborting build_where() due to possible SQL injection.") 1138 return '' 1139 1140 query_flavor = getattr(connector, 'flavor', flavor) if connector is not None else flavor 1141 where = "" 1142 leading_and = "\n AND " 1143 for key, value in params.items(): 1144 _key = sql_item_name(key, query_flavor, None) 1145 ### search across a list (i.e. IN syntax) 1146 if isinstance(value, Iterable) and not isinstance(value, (dict, str)): 1147 includes = [ 1148 none_if_null(item) 1149 for item in value 1150 if not str(item).startswith(negation_prefix) 1151 ] 1152 null_includes = [item for item in includes if item is None] 1153 not_null_includes = [item for item in includes if item is not None] 1154 excludes = [ 1155 none_if_null(str(item)[len(negation_prefix):]) 1156 for item in value 1157 if str(item).startswith(negation_prefix) 1158 ] 1159 null_excludes = [item for item in excludes if item is None] 1160 not_null_excludes = [item for item in excludes if item is not None] 1161 1162 if includes: 1163 where += f"{leading_and}(" 1164 if not_null_includes: 1165 where += f"{_key} IN (" 1166 for item in not_null_includes: 1167 quoted_item = str(item).replace("'", "''") 1168 where += f"'{quoted_item}', " 1169 where = where[:-2] + ")" 1170 if null_includes: 1171 where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL" 1172 if includes: 1173 where += ")" 1174 1175 if excludes: 1176 where += f"{leading_and}(" 1177 if not_null_excludes: 1178 where += f"{_key} NOT IN (" 1179 for item in not_null_excludes: 1180 quoted_item = str(item).replace("'", "''") 1181 where += f"'{quoted_item}', " 1182 where = where[:-2] + ")" 1183 if null_excludes: 1184 where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL" 1185 if excludes: 1186 where += ")" 1187 1188 continue 1189 1190 ### search a dictionary 1191 elif isinstance(value, dict): 1192 import json 1193 where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'") 1194 continue 1195 1196 eq_sign = '=' 1197 is_null = 'IS NULL' 1198 if value_is_null(str(value).lstrip(negation_prefix)): 1199 value = ( 1200 (negation_prefix + 'None') 1201 if str(value).startswith(negation_prefix) 1202 else None 1203 ) 1204 if str(value).startswith(negation_prefix): 1205 value = str(value)[len(negation_prefix):] 1206 eq_sign = '!=' 1207 if value_is_null(value): 1208 value = None 1209 is_null = 'IS NOT NULL' 1210 quoted_value = str(value).replace("'", "''") 1211 where += ( 1212 f"{leading_and}{_key} " 1213 + (is_null if value is None else f"{eq_sign} '{quoted_value}'") 1214 ) 1215 1216 if len(where) > 1: 1217 where = ("\nWHERE\n " if with_where else '') + where[len(leading_and):] 1218 return where 1219 1220 1221def table_exists( 1222 table: str, 1223 connector: mrsm.connectors.sql.SQLConnector, 1224 schema: Optional[str] = None, 1225 debug: bool = False, 1226) -> bool: 1227 """Check if a table exists. 1228 1229 Parameters 1230 ---------- 1231 table: str: 1232 The name of the table in question. 1233 1234 connector: mrsm.connectors.sql.SQLConnector 1235 The connector to the database which holds the table. 1236 1237 schema: Optional[str], default None 1238 Optionally specify the table schema. 1239 Defaults to `connector.schema`. 1240 1241 debug: bool, default False : 1242 Verbosity toggle. 1243 1244 Returns 1245 ------- 1246 A `bool` indicating whether or not the table exists on the database. 1247 """ 1248 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1249 schema = schema or connector.schema 1250 insp = sqlalchemy.inspect(connector.engine) 1251 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1252 return insp.has_table(truncated_table_name, schema=schema) 1253 1254 1255def get_sqlalchemy_table( 1256 table: str, 1257 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1258 schema: Optional[str] = None, 1259 refresh: bool = False, 1260 debug: bool = False, 1261) -> Union['sqlalchemy.Table', None]: 1262 """ 1263 Construct a SQLAlchemy table from its name. 1264 1265 Parameters 1266 ---------- 1267 table: str 1268 The name of the table on the database. Does not need to be escaped. 1269 1270 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1271 The connector to the database which holds the table. 1272 1273 schema: Optional[str], default None 1274 Specify on which schema the table resides. 1275 Defaults to the schema set in `connector`. 1276 1277 refresh: bool, default False 1278 If `True`, rebuild the cached table object. 1279 1280 debug: bool, default False: 1281 Verbosity toggle. 1282 1283 Returns 1284 ------- 1285 A `sqlalchemy.Table` object for the table. 1286 1287 """ 1288 if connector is None: 1289 from meerschaum import get_connector 1290 connector = get_connector('sql') 1291 1292 if connector.flavor == 'duckdb': 1293 return None 1294 1295 from meerschaum.connectors.sql.tables import get_tables 1296 from meerschaum.utils.packages import attempt_import 1297 from meerschaum.utils.warnings import warn 1298 if refresh: 1299 connector.metadata.clear() 1300 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 1301 sqlalchemy = attempt_import('sqlalchemy', lazy=False) 1302 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1303 table_kwargs = { 1304 'autoload_with': connector.engine, 1305 } 1306 if schema: 1307 table_kwargs['schema'] = schema 1308 1309 if refresh or truncated_table_name not in tables: 1310 try: 1311 tables[truncated_table_name] = sqlalchemy.Table( 1312 truncated_table_name, 1313 connector.metadata, 1314 **table_kwargs 1315 ) 1316 except sqlalchemy.exc.NoSuchTableError: 1317 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 1318 return None 1319 return tables[truncated_table_name] 1320 1321 1322def get_table_cols_types( 1323 table: str, 1324 connectable: Union[ 1325 'mrsm.connectors.sql.SQLConnector', 1326 'sqlalchemy.orm.session.Session', 1327 'sqlalchemy.engine.base.Engine' 1328 ], 1329 flavor: Optional[str] = None, 1330 schema: Optional[str] = None, 1331 database: Optional[str] = None, 1332 debug: bool = False, 1333) -> Dict[str, str]: 1334 """ 1335 Return a dictionary mapping a table's columns to data types. 1336 This is useful for inspecting tables creating during a not-yet-committed session. 1337 1338 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1339 Use this function if you are confident the table name is unique or if you have 1340 and explicit schema. 1341 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1342 1343 Parameters 1344 ---------- 1345 table: str 1346 The name of the table (unquoted). 1347 1348 connectable: Union[ 1349 'mrsm.connectors.sql.SQLConnector', 1350 'sqlalchemy.orm.session.Session', 1351 'sqlalchemy.engine.base.Engine' 1352 ] 1353 The connection object used to fetch the columns and types. 1354 1355 flavor: Optional[str], default None 1356 The database dialect flavor to use for the query. 1357 If omitted, default to `connectable.flavor`. 1358 1359 schema: Optional[str], default None 1360 If provided, restrict the query to this schema. 1361 1362 database: Optional[str]. default None 1363 If provided, restrict the query to this database. 1364 1365 Returns 1366 ------- 1367 A dictionary mapping column names to data types. 1368 """ 1369 import textwrap 1370 from meerschaum.connectors import SQLConnector 1371 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1372 flavor = flavor or getattr(connectable, 'flavor', None) 1373 if not flavor: 1374 raise ValueError("Please provide a database flavor.") 1375 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1376 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1377 if flavor in NO_SCHEMA_FLAVORS: 1378 schema = None 1379 if schema is None: 1380 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1381 if flavor in ('sqlite', 'duckdb', 'oracle', 'geopackage'): 1382 database = None 1383 table_trunc = truncate_item_name(table, flavor=flavor) 1384 table_lower = table.lower() 1385 table_upper = table.upper() 1386 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1387 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1388 db_prefix = ( 1389 "tempdb." 1390 if flavor == 'mssql' and table.startswith('#') 1391 else "" 1392 ) 1393 1394 def _esc(s: str) -> str: 1395 return s.replace("'", "''") 1396 1397 cols_types_query = sqlalchemy.text( 1398 textwrap.dedent(columns_types_queries.get( 1399 flavor, 1400 columns_types_queries['default'] 1401 ).format( 1402 table=_esc(table), 1403 table_trunc=_esc(table_trunc), 1404 table_lower=_esc(table_lower), 1405 table_lower_trunc=_esc(table_lower_trunc), 1406 table_upper=_esc(table_upper), 1407 table_upper_trunc=_esc(table_upper_trunc), 1408 db_prefix=db_prefix, 1409 )).lstrip().rstrip() 1410 ) 1411 1412 cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale'] 1413 result_cols_ix = dict(enumerate(cols)) 1414 1415 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1416 if not debug_kwargs and debug: 1417 dprint(cols_types_query) 1418 1419 try: 1420 result_rows = ( 1421 [ 1422 row 1423 for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall() 1424 ] 1425 if flavor != 'duckdb' 1426 else [ 1427 tuple([doc[col] for col in cols]) 1428 for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records') 1429 ] 1430 ) 1431 cols_types_docs = [ 1432 { 1433 result_cols_ix[i]: val 1434 for i, val in enumerate(row) 1435 } 1436 for row in result_rows 1437 ] 1438 cols_types_docs_filtered = [ 1439 doc 1440 for doc in cols_types_docs 1441 if ( 1442 ( 1443 not schema 1444 or doc['schema'] == schema 1445 ) 1446 and 1447 ( 1448 not database 1449 or doc['database'] == database 1450 ) 1451 ) 1452 ] 1453 1454 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1455 if cols_types_docs and not cols_types_docs_filtered: 1456 cols_types_docs_filtered = cols_types_docs 1457 1458 ### NOTE: Check for PostGIS GEOMETRY columns. 1459 geometry_cols_types = {} 1460 user_defined_cols = [ 1461 doc 1462 for doc in cols_types_docs_filtered 1463 if str(doc.get('type', None)).upper() == 'USER-DEFINED' 1464 ] 1465 if user_defined_cols: 1466 geometry_cols_types.update( 1467 get_postgis_geo_columns_types( 1468 connectable, 1469 table, 1470 schema=schema, 1471 debug=debug, 1472 ) 1473 ) 1474 1475 cols_types = { 1476 ( 1477 doc['column'] 1478 if flavor != 'oracle' else ( 1479 ( 1480 doc['column'].lower() 1481 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1482 else doc['column'] 1483 ) 1484 ) 1485 ): doc['type'].upper() + ( 1486 f'({precision},{scale})' 1487 if ( 1488 (precision := doc.get('numeric_precision', None)) 1489 and 1490 (scale := doc.get('numeric_scale', None)) 1491 ) 1492 else '' 1493 ) 1494 for doc in cols_types_docs_filtered 1495 } 1496 cols_types.update(geometry_cols_types) 1497 return cols_types 1498 except Exception as e: 1499 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1500 return {} 1501 1502 1503def get_table_cols_indices( 1504 table: str, 1505 connectable: Union[ 1506 'mrsm.connectors.sql.SQLConnector', 1507 'sqlalchemy.orm.session.Session', 1508 'sqlalchemy.engine.base.Engine' 1509 ], 1510 flavor: Optional[str] = None, 1511 schema: Optional[str] = None, 1512 database: Optional[str] = None, 1513 debug: bool = False, 1514) -> Dict[str, List[str]]: 1515 """ 1516 Return a dictionary mapping a table's columns to lists of indices. 1517 This is useful for inspecting tables creating during a not-yet-committed session. 1518 1519 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1520 Use this function if you are confident the table name is unique or if you have 1521 and explicit schema. 1522 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1523 1524 Parameters 1525 ---------- 1526 table: str 1527 The name of the table (unquoted). 1528 1529 connectable: Union[ 1530 'mrsm.connectors.sql.SQLConnector', 1531 'sqlalchemy.orm.session.Session', 1532 'sqlalchemy.engine.base.Engine' 1533 ] 1534 The connection object used to fetch the columns and types. 1535 1536 flavor: Optional[str], default None 1537 The database dialect flavor to use for the query. 1538 If omitted, default to `connectable.flavor`. 1539 1540 schema: Optional[str], default None 1541 If provided, restrict the query to this schema. 1542 1543 database: Optional[str]. default None 1544 If provided, restrict the query to this database. 1545 1546 Returns 1547 ------- 1548 A dictionary mapping column names to a list of indices. 1549 """ 1550 import textwrap 1551 from collections import defaultdict 1552 from meerschaum.connectors import SQLConnector 1553 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1554 flavor = flavor or getattr(connectable, 'flavor', None) 1555 if not flavor: 1556 raise ValueError("Please provide a database flavor.") 1557 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1558 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1559 if flavor in NO_SCHEMA_FLAVORS: 1560 schema = None 1561 if schema is None: 1562 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1563 if flavor in ('sqlite', 'duckdb', 'oracle', 'geopackage'): 1564 database = None 1565 table_trunc = truncate_item_name(table, flavor=flavor) 1566 table_lower = table.lower() 1567 table_upper = table.upper() 1568 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1569 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1570 db_prefix = ( 1571 "tempdb." 1572 if flavor == 'mssql' and table.startswith('#') 1573 else "" 1574 ) 1575 1576 def _esc(s: str) -> str: 1577 return s.replace("'", "''") 1578 1579 cols_indices_query = sqlalchemy.text( 1580 textwrap.dedent(columns_indices_queries.get( 1581 flavor, 1582 columns_indices_queries['default'] 1583 ).format( 1584 table=_esc(table), 1585 table_trunc=_esc(table_trunc), 1586 table_lower=_esc(table_lower), 1587 table_lower_trunc=_esc(table_lower_trunc), 1588 table_upper=_esc(table_upper), 1589 table_upper_trunc=_esc(table_upper_trunc), 1590 db_prefix=db_prefix, 1591 schema=_esc(schema) if schema else schema, 1592 )).lstrip().rstrip() 1593 ) 1594 1595 cols = ['database', 'schema', 'table', 'column', 'index', 'index_type'] 1596 if flavor == 'mssql': 1597 cols.append('clustered') 1598 result_cols_ix = dict(enumerate(cols)) 1599 1600 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1601 if not debug_kwargs and debug: 1602 dprint(cols_indices_query) 1603 1604 try: 1605 result_rows = ( 1606 [ 1607 row 1608 for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall() 1609 ] 1610 if flavor != 'duckdb' 1611 else [ 1612 tuple([doc[col] for col in cols]) 1613 for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records') 1614 ] 1615 ) 1616 cols_types_docs = [ 1617 { 1618 result_cols_ix[i]: val 1619 for i, val in enumerate(row) 1620 } 1621 for row in result_rows 1622 ] 1623 cols_types_docs_filtered = [ 1624 doc 1625 for doc in cols_types_docs 1626 if ( 1627 ( 1628 not schema 1629 or doc['schema'] == schema 1630 ) 1631 and 1632 ( 1633 not database 1634 or doc['database'] == database 1635 ) 1636 ) 1637 ] 1638 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1639 if cols_types_docs and not cols_types_docs_filtered: 1640 cols_types_docs_filtered = cols_types_docs 1641 1642 cols_indices = defaultdict(lambda: []) 1643 for doc in cols_types_docs_filtered: 1644 col = ( 1645 doc['column'] 1646 if flavor != 'oracle' 1647 else ( 1648 doc['column'].lower() 1649 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1650 else doc['column'] 1651 ) 1652 ) 1653 index_doc = { 1654 'name': doc.get('index', None), 1655 'type': doc.get('index_type', None) 1656 } 1657 if flavor == 'mssql': 1658 index_doc['clustered'] = doc.get('clustered', None) 1659 cols_indices[col].append(index_doc) 1660 1661 return dict(cols_indices) 1662 except Exception as e: 1663 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1664 return {} 1665 1666 1667def get_update_queries( 1668 target: str, 1669 patch: str, 1670 connectable: Union[ 1671 mrsm.connectors.sql.SQLConnector, 1672 'sqlalchemy.orm.session.Session' 1673 ], 1674 join_cols: Iterable[str], 1675 flavor: Optional[str] = None, 1676 upsert: bool = False, 1677 datetime_col: Optional[str] = None, 1678 schema: Optional[str] = None, 1679 patch_schema: Optional[str] = None, 1680 target_cols_types: Optional[Dict[str, str]] = None, 1681 patch_cols_types: Optional[Dict[str, str]] = None, 1682 identity_insert: bool = False, 1683 null_indices: bool = True, 1684 cast_columns: bool = True, 1685 debug: bool = False, 1686) -> List[str]: 1687 """ 1688 Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table. 1689 1690 Parameters 1691 ---------- 1692 target: str 1693 The name of the target table. 1694 1695 patch: str 1696 The name of the patch table. This should have the same shape as the target. 1697 1698 connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session] 1699 The `SQLConnector` or SQLAlchemy session which will later execute the queries. 1700 1701 join_cols: List[str] 1702 The columns to use to join the patch to the target. 1703 1704 flavor: Optional[str], default None 1705 If using a SQLAlchemy session, provide the expected database flavor. 1706 1707 upsert: bool, default False 1708 If `True`, return an upsert query rather than an update. 1709 1710 datetime_col: Optional[str], default None 1711 If provided, bound the join query using this column as the datetime index. 1712 This must be present on both tables. 1713 1714 schema: Optional[str], default None 1715 If provided, use this schema when quoting the target table. 1716 Defaults to `connector.schema`. 1717 1718 patch_schema: Optional[str], default None 1719 If provided, use this schema when quoting the patch table. 1720 Defaults to `schema`. 1721 1722 target_cols_types: Optional[Dict[str, Any]], default None 1723 If provided, use these as the columns-types dictionary for the target table. 1724 Default will infer from the database context. 1725 1726 patch_cols_types: Optional[Dict[str, Any]], default None 1727 If provided, use these as the columns-types dictionary for the target table. 1728 Default will infer from the database context. 1729 1730 identity_insert: bool, default False 1731 If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries. 1732 Only applies for MSSQL upserts. 1733 1734 null_indices: bool, default True 1735 If `False`, do not coalesce index columns before joining. 1736 1737 cast_columns: bool, default True 1738 If `False`, do not cast update columns to the target table types. 1739 1740 debug: bool, default False 1741 Verbosity toggle. 1742 1743 Returns 1744 ------- 1745 A list of query strings to perform the update operation. 1746 """ 1747 import textwrap 1748 from meerschaum.connectors import SQLConnector 1749 from meerschaum.utils.debug import dprint 1750 from meerschaum.utils.dtypes import are_dtypes_equal 1751 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type 1752 flavor = flavor or getattr(connectable, 'flavor', None) 1753 if not flavor: 1754 raise ValueError("Provide a flavor if using a SQLAlchemy session.") 1755 if ( 1756 flavor in ('sqlite', 'geopackage') 1757 and isinstance(connectable, SQLConnector) 1758 and connectable.db_version < '3.33.0' 1759 ): 1760 flavor = 'sqlite_delete_insert' 1761 flavor_key = (f'{flavor}-upsert' if upsert else flavor) 1762 base_queries = UPDATE_QUERIES.get( 1763 flavor_key, 1764 UPDATE_QUERIES['default'] 1765 ) 1766 if not isinstance(base_queries, list): 1767 base_queries = [base_queries] 1768 schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None) 1769 patch_schema = patch_schema or schema 1770 target_table_columns = get_table_cols_types( 1771 target, 1772 connectable, 1773 flavor=flavor, 1774 schema=schema, 1775 debug=debug, 1776 ) if not target_cols_types else target_cols_types 1777 patch_table_columns = get_table_cols_types( 1778 patch, 1779 connectable, 1780 flavor=flavor, 1781 schema=patch_schema, 1782 debug=debug, 1783 ) if not patch_cols_types else patch_cols_types 1784 1785 patch_cols_str = ', '.join( 1786 [ 1787 sql_item_name(col, flavor) 1788 for col in patch_table_columns 1789 ] 1790 ) 1791 patch_cols_prefixed_str = ', '.join( 1792 [ 1793 'p.' + sql_item_name(col, flavor) 1794 for col in patch_table_columns 1795 ] 1796 ) 1797 1798 join_cols_str = ', '.join( 1799 [ 1800 sql_item_name(col, flavor) 1801 for col in join_cols 1802 ] 1803 ) 1804 1805 value_cols = [] 1806 join_cols_types = [] 1807 if debug: 1808 dprint("target_table_columns:") 1809 mrsm.pprint(target_table_columns) 1810 for c_name, c_type in target_table_columns.items(): 1811 if c_name not in patch_table_columns: 1812 continue 1813 if flavor in DB_FLAVORS_CAST_DTYPES: 1814 c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type) 1815 ( 1816 join_cols_types 1817 if c_name in join_cols 1818 else value_cols 1819 ).append((c_name, c_type)) 1820 if debug: 1821 dprint(f"value_cols: {value_cols}") 1822 1823 if not join_cols_types: 1824 return [] 1825 if not value_cols and not upsert: 1826 return [] 1827 1828 coalesce_join_cols_str = ', '.join( 1829 [ 1830 ( 1831 ( 1832 'COALESCE(' 1833 + sql_item_name(c_name, flavor) 1834 + ', ' 1835 + get_null_replacement(c_type, flavor) 1836 + ')' 1837 ) 1838 if null_indices 1839 else sql_item_name(c_name, flavor) 1840 ) 1841 for c_name, c_type in join_cols_types 1842 ] 1843 ) 1844 1845 update_or_nothing = ('UPDATE' if value_cols else 'NOTHING') 1846 1847 def sets_subquery(l_prefix: str, r_prefix: str): 1848 if not value_cols: 1849 return '' 1850 1851 utc_value_cols = { 1852 c_name 1853 for c_name, c_type in value_cols 1854 if ('utc' in get_pd_type_from_db_type(c_type).lower()) 1855 } if flavor not in TIMEZONE_NAIVE_FLAVORS else set() 1856 1857 cast_func_cols = { 1858 c_name: ( 1859 ('', '', '') 1860 if not cast_columns or ( 1861 flavor == 'oracle' 1862 and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes') 1863 ) 1864 else ( 1865 ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + ( 1866 " AT TIME ZONE 'UTC'" 1867 if c_name in utc_value_cols 1868 else '' 1869 )) 1870 if flavor not in ('sqlite', 'geopackage') 1871 else ('', '', '') 1872 ) 1873 ) 1874 for c_name, c_type in value_cols 1875 } 1876 return 'SET ' + ',\n'.join([ 1877 ( 1878 l_prefix + sql_item_name(c_name, flavor, None) 1879 + ' = ' 1880 + cast_func_cols[c_name][0] 1881 + r_prefix + sql_item_name(c_name, flavor, None) 1882 + cast_func_cols[c_name][1] 1883 + cast_func_cols[c_name][2] 1884 ) for c_name, c_type in value_cols 1885 ]) 1886 1887 def and_subquery(l_prefix: str, r_prefix: str): 1888 return '\n AND\n '.join([ 1889 ( 1890 ( 1891 "COALESCE(" 1892 + l_prefix 1893 + sql_item_name(c_name, flavor, None) 1894 + ", " 1895 + get_null_replacement(c_type, flavor) 1896 + ")" 1897 + '\n =\n ' 1898 + "COALESCE(" 1899 + r_prefix 1900 + sql_item_name(c_name, flavor, None) 1901 + ", " 1902 + get_null_replacement(c_type, flavor) 1903 + ")" 1904 ) 1905 if null_indices 1906 else ( 1907 l_prefix 1908 + sql_item_name(c_name, flavor, None) 1909 + ' = ' 1910 + r_prefix 1911 + sql_item_name(c_name, flavor, None) 1912 ) 1913 ) for c_name, c_type in join_cols_types 1914 ]) 1915 1916 skip_query_val = "" 1917 target_table_name = sql_item_name(target, flavor, schema) 1918 patch_table_name = sql_item_name(patch, flavor, patch_schema) 1919 dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None 1920 date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]' 1921 min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]' 1922 max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]' 1923 date_bounds_subquery = ( 1924 f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table}) 1925 AND 1926 f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})""" 1927 if datetime_col 1928 else "1 = 1" 1929 ) 1930 with_temp_date_bounds = f"""WITH [date_bounds] AS ( 1931 SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name} 1932 FROM {patch_table_name} 1933 )""" if datetime_col else "" 1934 identity_insert_on = ( 1935 f"SET IDENTITY_INSERT {target_table_name} ON" 1936 if identity_insert 1937 else skip_query_val 1938 ) 1939 identity_insert_off = ( 1940 f"SET IDENTITY_INSERT {target_table_name} OFF" 1941 if identity_insert 1942 else skip_query_val 1943 ) 1944 1945 ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices. 1946 when_matched_update_sets_subquery_none = "" if not value_cols else ( 1947 "\n WHEN MATCHED THEN\n" 1948 f" UPDATE {sets_subquery('', 'p.')}" 1949 ) 1950 1951 cols_equal_values = '\n,'.join( 1952 [ 1953 f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})" 1954 for c_name, c_type in value_cols 1955 ] 1956 ) 1957 on_duplicate_key_update = ( 1958 "ON DUPLICATE KEY UPDATE" 1959 if value_cols 1960 else "" 1961 ) 1962 ignore = "IGNORE " if not value_cols else "" 1963 1964 formatted_queries = [ 1965 textwrap.dedent(base_query.format( 1966 sets_subquery_none=sets_subquery('', 'p.'), 1967 sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'), 1968 sets_subquery_f=sets_subquery('f.', 'p.'), 1969 and_subquery_f=and_subquery('p.', 'f.'), 1970 and_subquery_t=and_subquery('p.', 't.'), 1971 target_table_name=target_table_name, 1972 patch_table_name=patch_table_name, 1973 patch_cols_str=patch_cols_str, 1974 patch_cols_prefixed_str=patch_cols_prefixed_str, 1975 date_bounds_subquery=date_bounds_subquery, 1976 join_cols_str=join_cols_str, 1977 coalesce_join_cols_str=coalesce_join_cols_str, 1978 update_or_nothing=update_or_nothing, 1979 when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none, 1980 cols_equal_values=cols_equal_values, 1981 on_duplicate_key_update=on_duplicate_key_update, 1982 ignore=ignore, 1983 with_temp_date_bounds=with_temp_date_bounds, 1984 identity_insert_on=identity_insert_on, 1985 identity_insert_off=identity_insert_off, 1986 )).lstrip().rstrip() 1987 for base_query in base_queries 1988 ] 1989 1990 ### NOTE: Allow for skipping some queries. 1991 return [query for query in formatted_queries if query] 1992 1993 1994def get_null_replacement(typ: str, flavor: str) -> str: 1995 """ 1996 Return a value that may temporarily be used in place of NULL for this type. 1997 1998 Parameters 1999 ---------- 2000 typ: str 2001 The typ to be converted to NULL. 2002 2003 flavor: str 2004 The database flavor for which this value will be used. 2005 2006 Returns 2007 ------- 2008 A value which may stand in place of NULL for this type. 2009 `'None'` is returned if a value cannot be determined. 2010 """ 2011 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 2012 if 'geometry' in typ.lower(): 2013 return "'010100000000008058346FCDC100008058346FCDC1'" 2014 if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'): 2015 return '-987654321' 2016 if 'bool' in typ.lower() or typ.lower() == 'bit': 2017 bool_typ = ( 2018 PD_TO_DB_DTYPES_FLAVORS 2019 .get('bool', {}) 2020 .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default']) 2021 ) 2022 if flavor in DB_FLAVORS_CAST_DTYPES: 2023 bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ) 2024 val_to_cast = ( 2025 -987654321 2026 if flavor in ('mysql', 'mariadb') 2027 else 0 2028 ) 2029 return f'CAST({val_to_cast} AS {bool_typ})' 2030 if 'time' in typ.lower() or 'date' in typ.lower(): 2031 db_type = typ if typ.isupper() else None 2032 return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type) 2033 if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',): 2034 return '-987654321.0' 2035 if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char': 2036 return "'-987654321'" 2037 if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'): 2038 return '00' 2039 if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'): 2040 magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5' 2041 if flavor == 'mssql': 2042 return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)" 2043 return f"'{magic_val}'" 2044 return ('n' if flavor == 'oracle' else '') + "'-987654321'" 2045 2046 2047def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]: 2048 """ 2049 Fetch the database version if possible. 2050 """ 2051 version_name = sql_item_name('version', conn.flavor, None) 2052 version_query = version_queries.get( 2053 conn.flavor, 2054 version_queries['default'] 2055 ).format(version_name=version_name) 2056 return conn.value(version_query, debug=debug) 2057 2058 2059def get_rename_table_queries( 2060 old_table: str, 2061 new_table: str, 2062 flavor: str, 2063 schema: Optional[str] = None, 2064 new_schema: Optional[str] = None, 2065) -> List[str]: 2066 """ 2067 Return queries to alter a table's name. 2068 2069 Parameters 2070 ---------- 2071 old_table: str 2072 The unquoted name of the old table. 2073 2074 new_table: str 2075 The unquoted name of the new table. 2076 2077 flavor: str 2078 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2079 2080 schema: Optional[str], default None 2081 The schema on which the table resides. 2082 2083 new_schema: Optional[str], default None 2084 If provided, move the table to this schema. 2085 Defaults to `schema`. 2086 2087 Returns 2088 ------- 2089 A list of `ALTER TABLE` or equivalent queries for the database flavor. 2090 """ 2091 if new_schema is None: 2092 new_schema = schema 2093 2094 if flavor in ('sqlite', 'geopackage'): 2095 schema = None 2096 new_schema = None 2097 2098 old_table_name = sql_item_name(old_table, flavor, schema) 2099 new_table_name = sql_item_name(new_table, flavor, None) 2100 2101 if flavor == 'mssql': 2102 queries = [] 2103 if schema != new_schema: 2104 queries.append(f"ALTER SCHEMA {sql_item_name(new_schema, flavor)} TRANSFER {old_table_name}") 2105 schema = new_schema 2106 old_table_name = sql_item_name(old_table, flavor, schema) 2107 if old_table != new_table: 2108 queries.append(f"EXEC sp_rename '{old_table_name}', '{new_table}'") 2109 return queries 2110 2111 if flavor in ('mysql', 'mariadb'): 2112 new_table_qualified = sql_item_name(new_table, flavor, new_schema) 2113 return [f"RENAME TABLE {old_table_name} TO {new_table_qualified}"] 2114 2115 if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else "" 2116 if flavor == 'duckdb': 2117 tmp_table = '_tmp_rename_' + new_table 2118 tmp_table_name = sql_item_name(tmp_table, flavor, schema) 2119 return ( 2120 get_create_table_queries( 2121 f"SELECT * FROM {old_table_name}", 2122 tmp_table, 2123 'duckdb', 2124 schema, 2125 ) + get_create_table_queries( 2126 f"SELECT * FROM {tmp_table_name}", 2127 new_table, 2128 'duckdb', 2129 new_schema, 2130 ) + [ 2131 f"DROP TABLE {if_exists_str} {tmp_table_name}", 2132 f"DROP TABLE {if_exists_str} {old_table_name}", 2133 ] 2134 ) 2135 2136 queries = [] 2137 if schema != new_schema: 2138 if flavor in ('postgresql', 'postgis', 'timescaledb', 'timescaledb-ha', 'citus', 'cockroachdb'): 2139 queries.append(f"ALTER TABLE {old_table_name} SET SCHEMA {sql_item_name(new_schema, flavor)}") 2140 schema = new_schema 2141 old_table_name = sql_item_name(old_table, flavor, schema) 2142 2143 if old_table != new_table: 2144 queries.append(f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}") 2145 2146 return queries 2147 2148 2149def get_create_table_query( 2150 query_or_dtypes: Union[str, Dict[str, str]], 2151 new_table: str, 2152 flavor: str, 2153 schema: Optional[str] = None, 2154) -> str: 2155 """ 2156 NOTE: This function is deprecated. Use `get_create_table_queries()` instead. 2157 2158 Return a query to create a new table from a `SELECT` query. 2159 2160 Parameters 2161 ---------- 2162 query: Union[str, Dict[str, str]] 2163 The select query to use for the creation of the table. 2164 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 2165 2166 new_table: str 2167 The unquoted name of the new table. 2168 2169 flavor: str 2170 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2171 2172 schema: Optional[str], default None 2173 The schema on which the table will reside. 2174 2175 Returns 2176 ------- 2177 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2178 """ 2179 return get_create_table_queries( 2180 query_or_dtypes, 2181 new_table, 2182 flavor, 2183 schema=schema, 2184 primary_key=None, 2185 )[0] 2186 2187 2188def get_create_table_queries( 2189 query_or_dtypes: Union[str, Dict[str, str]], 2190 new_table: str, 2191 flavor: str, 2192 schema: Optional[str] = None, 2193 primary_key: Optional[str] = None, 2194 primary_key_db_type: Optional[str] = None, 2195 autoincrement: bool = False, 2196 datetime_column: Optional[str] = None, 2197 _parse_dtypes: bool = True, 2198) -> List[str]: 2199 """ 2200 Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary. 2201 2202 Parameters 2203 ---------- 2204 query_or_dtypes: Union[str, Dict[str, str]] 2205 The select query to use for the creation of the table. 2206 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 2207 2208 new_table: str 2209 The unquoted name of the new table. 2210 2211 flavor: str 2212 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2213 2214 schema: Optional[str], default None 2215 The schema on which the table will reside. 2216 2217 primary_key: Optional[str], default None 2218 If provided, designate this column as the primary key in the new table. 2219 2220 primary_key_db_type: Optional[str], default None 2221 If provided, alter the primary key to this type (to set NOT NULL constraint). 2222 2223 autoincrement: bool, default False 2224 If `True` and `primary_key` is provided, create the `primary_key` column 2225 as an auto-incrementing integer column. 2226 2227 datetime_column: Optional[str], default None 2228 If provided, include this column in the primary key. 2229 Applicable to TimescaleDB only. 2230 2231 _parse_dtypes: bool, default True 2232 If `True`, cast Pandas dtypes to SQL dtypes. 2233 Otherwise pass through the given value directly. 2234 2235 Returns 2236 ------- 2237 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2238 """ 2239 if not isinstance(query_or_dtypes, (str, dict)): 2240 raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.") 2241 2242 method = ( 2243 _get_create_table_query_from_cte 2244 if isinstance(query_or_dtypes, str) 2245 else _get_create_table_query_from_dtypes 2246 ) 2247 return method( 2248 query_or_dtypes, 2249 new_table, 2250 flavor, 2251 schema=schema, 2252 primary_key=primary_key, 2253 primary_key_db_type=primary_key_db_type, 2254 autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS), 2255 datetime_column=datetime_column, 2256 _parse_dtypes=_parse_dtypes, 2257 ) 2258 2259 2260def _get_create_table_query_from_dtypes( 2261 dtypes: Dict[str, str], 2262 new_table: str, 2263 flavor: str, 2264 schema: Optional[str] = None, 2265 primary_key: Optional[str] = None, 2266 primary_key_db_type: Optional[str] = None, 2267 autoincrement: bool = False, 2268 datetime_column: Optional[str] = None, 2269 _parse_dtypes: bool = True, 2270) -> List[str]: 2271 """ 2272 Create a new table from a `dtypes` dictionary. 2273 """ 2274 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, AUTO_INCREMENT_COLUMN_FLAVORS 2275 if not dtypes and not primary_key: 2276 raise ValueError(f"Expecting columns for table '{new_table}'.") 2277 2278 if flavor in SKIP_AUTO_INCREMENT_FLAVORS: 2279 autoincrement = False 2280 2281 cols_types = ( 2282 [ 2283 ( 2284 primary_key, 2285 get_db_type_from_pd_type(dtypes.get(primary_key, 'int') or 'int', flavor=flavor) 2286 ) if _parse_dtypes else ( 2287 primary_key, 2288 dtypes.get(primary_key, 'INT') or 'INT' 2289 ) 2290 ] 2291 if primary_key 2292 else [] 2293 ) + [ 2294 ( 2295 col, 2296 ( 2297 get_db_type_from_pd_type(typ, flavor=flavor) 2298 if _parse_dtypes 2299 else typ 2300 ) 2301 ) 2302 for col, typ in dtypes.items() 2303 if col != primary_key 2304 ] 2305 2306 table_name = sql_item_name(new_table, schema=schema, flavor=flavor) 2307 primary_key_name = sql_item_name(primary_key, flavor) if primary_key else None 2308 primary_key_constraint_name = ( 2309 sql_item_name(f'PK_{new_table}', flavor, None) 2310 if primary_key 2311 else None 2312 ) 2313 datetime_column_name = sql_item_name(datetime_column, flavor) if datetime_column else None 2314 primary_key_clustered = ( 2315 "CLUSTERED" 2316 if not datetime_column or datetime_column == primary_key 2317 else "NONCLUSTERED" 2318 ) 2319 query = f"CREATE TABLE {table_name} (" 2320 if primary_key: 2321 col_db_type = cols_types[0][1] 2322 auto_increment_str = (' ' + AUTO_INCREMENT_COLUMN_FLAVORS.get( 2323 flavor, 2324 AUTO_INCREMENT_COLUMN_FLAVORS['default'] 2325 )) if autoincrement or primary_key not in dtypes else '' 2326 col_name = sql_item_name(primary_key, flavor=flavor, schema=None) 2327 2328 if flavor in ('sqlite', 'geopackage'): 2329 query += ( 2330 f"\n {col_name} " 2331 + (f"{col_db_type}" if not auto_increment_str else 'INTEGER') 2332 + f" PRIMARY KEY{auto_increment_str} NOT NULL," 2333 ) 2334 elif flavor == 'oracle': 2335 query += f"\n {col_name} {col_db_type} {auto_increment_str} PRIMARY KEY," 2336 elif ( 2337 flavor in ('timescaledb', 'timescaledb-ha') 2338 and datetime_column 2339 and datetime_column != primary_key 2340 ): 2341 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2342 elif flavor == 'mssql': 2343 query += f"\n {col_name} {col_db_type}{auto_increment_str} NOT NULL," 2344 else: 2345 query += f"\n {col_name} {col_db_type} PRIMARY KEY{auto_increment_str} NOT NULL," 2346 2347 for col, db_type in cols_types: 2348 if col == primary_key: 2349 continue 2350 col_name = sql_item_name(col, schema=None, flavor=flavor) 2351 query += f"\n {col_name} {db_type}," 2352 if ( 2353 flavor in ('timescaledb', 'timescaledb-ha') 2354 and datetime_column 2355 and primary_key 2356 and datetime_column != primary_key 2357 ): 2358 query += f"\n PRIMARY KEY({datetime_column_name}, {primary_key_name})," 2359 2360 if flavor == 'mssql' and primary_key: 2361 query += f"\n CONSTRAINT {primary_key_constraint_name} PRIMARY KEY {primary_key_clustered} ({primary_key_name})," 2362 2363 query = query[:-1] 2364 query += "\n)" 2365 2366 queries = [query] 2367 return queries 2368 2369 2370def _get_create_table_query_from_cte( 2371 query: str, 2372 new_table: str, 2373 flavor: str, 2374 schema: Optional[str] = None, 2375 primary_key: Optional[str] = None, 2376 primary_key_db_type: Optional[str] = None, 2377 autoincrement: bool = False, 2378 datetime_column: Optional[str] = None, 2379 _parse_dtypes=None, 2380) -> List[str]: 2381 """ 2382 Create a new table from a CTE query. 2383 """ 2384 import textwrap 2385 create_cte = 'create_query' 2386 create_cte_name = sql_item_name(create_cte, flavor, None) 2387 new_table_name = sql_item_name(new_table, flavor, schema) 2388 primary_key_constraint_name = ( 2389 sql_item_name(f'PK_{new_table}', flavor, None) 2390 if primary_key 2391 else None 2392 ) 2393 primary_key_name = ( 2394 sql_item_name(primary_key, flavor, None) 2395 if primary_key 2396 else None 2397 ) 2398 primary_key_clustered = ( 2399 "CLUSTERED" 2400 if not datetime_column or datetime_column == primary_key 2401 else "NONCLUSTERED" 2402 ) 2403 datetime_column_name = ( 2404 sql_item_name(datetime_column, flavor) 2405 if datetime_column 2406 else None 2407 ) 2408 query = query.lstrip() 2409 is_with_query = query.lower().startswith('with ') 2410 if flavor in ('mssql',): 2411 if is_with_query: 2412 final_select_ix = query.lower().rfind('select') 2413 create_table_queries = [ 2414 ( 2415 query[:final_select_ix].rstrip() + ',\n' 2416 + f"{create_cte_name} AS (\n" 2417 + textwrap.indent(query[final_select_ix:], ' ') 2418 + "\n)\n" 2419 + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}" 2420 ), 2421 ] 2422 else: 2423 create_table_queries = [ 2424 ( 2425 "SELECT *\n" 2426 f"INTO {new_table_name}\n" 2427 f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}" 2428 ), 2429 ] 2430 2431 alter_type_queries = [] 2432 if primary_key_db_type: 2433 alter_type_queries.extend([ 2434 ( 2435 f"ALTER TABLE {new_table_name}\n" 2436 f"ALTER COLUMN {primary_key_name} {primary_key_db_type} NOT NULL" 2437 ), 2438 ]) 2439 alter_type_queries.extend([ 2440 ( 2441 f"ALTER TABLE {new_table_name}\n" 2442 f"ADD CONSTRAINT {primary_key_constraint_name} " 2443 f"PRIMARY KEY {primary_key_clustered} ({primary_key_name})" 2444 ), 2445 ]) 2446 elif flavor in (None,): 2447 create_table_queries = [ 2448 ( 2449 f"WITH {create_cte_name} AS (\n{textwrap.indent(query, ' ')}\n)\n" 2450 f"CREATE TABLE {new_table_name} AS\n" 2451 "SELECT *\n" 2452 f"FROM {create_cte_name}" 2453 ), 2454 ] 2455 2456 alter_type_queries = [ 2457 ( 2458 f"ALTER TABLE {new_table_name}\n" 2459 f"ADD PRIMARY KEY ({primary_key_name})" 2460 ), 2461 ] 2462 elif flavor in ('sqlite', 'geopackage'): 2463 if is_with_query: 2464 create_table_queries = [ 2465 f"CREATE TABLE {new_table_name} AS\n{query}" 2466 ] 2467 else: 2468 create_table_queries = [ 2469 ( 2470 f"CREATE TABLE {new_table_name} AS\n" 2471 "SELECT *\n" 2472 f"FROM (\n{textwrap.indent(query, ' ')}\n)" 2473 ), 2474 ] 2475 ### SQLite does not support ALTER TABLE ... ADD PRIMARY KEY. 2476 alter_type_queries = [] 2477 elif flavor in ( 2478 'mysql', 'mariadb', 'duckdb', 'oracle', 2479 'postgresql', 'postgis', 'timescaledb', 'timescaledb-ha', 'citus', 'cockroachdb' 2480 ): 2481 if is_with_query: 2482 create_table_queries = [ 2483 f"CREATE TABLE {new_table_name} AS\n{query}" 2484 ] 2485 else: 2486 create_table_queries = [ 2487 ( 2488 f"CREATE TABLE {new_table_name} AS\n" 2489 "SELECT *\n" 2490 f"FROM (\n{textwrap.indent(query, ' ')}\n)" 2491 + (f" AS {create_cte_name}" if flavor != 'oracle' else '') 2492 ), 2493 ] 2494 2495 pk_cols_str = ( 2496 f"{datetime_column_name}, {primary_key_name}" 2497 if ( 2498 flavor in ('timescaledb', 'timescaledb-ha') 2499 and datetime_column 2500 and datetime_column != primary_key 2501 ) 2502 else f"{primary_key_name}" 2503 ) 2504 alter_type_queries = [ 2505 ( 2506 f"ALTER TABLE {new_table_name}\n" 2507 f"ADD PRIMARY KEY ({pk_cols_str})" 2508 ), 2509 ] 2510 else: 2511 if is_with_query: 2512 create_table_queries = [ 2513 ( 2514 "SELECT *\n" 2515 f"INTO {new_table_name}\n" 2516 f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}" 2517 ), 2518 ] 2519 else: 2520 create_table_queries = [ 2521 ( 2522 "SELECT *\n" 2523 f"INTO {new_table_name}\n" 2524 f"FROM (\n{textwrap.indent(query, ' ')}\n) AS {create_cte_name}" 2525 ), 2526 ] 2527 2528 alter_type_queries = [ 2529 ( 2530 f"ALTER TABLE {new_table_name}\n" 2531 f"ADD PRIMARY KEY ({primary_key_name})" 2532 ), 2533 ] 2534 2535 if not primary_key: 2536 return create_table_queries 2537 2538 return create_table_queries + alter_type_queries 2539 2540 2541def wrap_query_with_cte( 2542 sub_query: str, 2543 parent_query: str, 2544 flavor: str, 2545 cte_name: str = "src", 2546) -> str: 2547 """ 2548 Wrap a subquery in a CTE and append an encapsulating query. 2549 2550 Parameters 2551 ---------- 2552 sub_query: str 2553 The query to be referenced. This may itself contain CTEs. 2554 Unless `cte_name` is provided, this will be aliased as `src`. 2555 2556 parent_query: str 2557 The larger query to append which references the subquery. 2558 This must not contain CTEs. 2559 2560 flavor: str 2561 The database flavor, e.g. `'mssql'`. 2562 2563 cte_name: str, default 'src' 2564 The CTE alias, defaults to `src`. 2565 2566 Returns 2567 ------- 2568 An encapsulating query which allows you to treat `sub_query` as a temporary table. 2569 2570 Examples 2571 -------- 2572 2573 >>> from meerschaum.utils.sql import wrap_query_with_cte 2574 >>> sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo" 2575 >>> parent_query = "SELECT newval * 3 FROM src" 2576 >>> query = wrap_query_with_cte(sub_query, parent_query, 'mssql') 2577 >>> print(query) 2578 >>> # WITH foo AS (SELECT 1 AS val), 2579 >>> # [src] AS ( 2580 >>> # SELECT (val * 2) AS newval FROM foo 2581 >>> # ) 2582 >>> # SELECT newval * 3 FROM src 2583 2584 """ 2585 import textwrap 2586 sub_query = sub_query.lstrip() 2587 cte_name_quoted = sql_item_name(cte_name, flavor, None) 2588 2589 if flavor in NO_CTE_FLAVORS: 2590 return ( 2591 parent_query 2592 .replace(cte_name_quoted, '--MRSM_SUBQUERY--') 2593 .replace(cte_name, '--MRSM_SUBQUERY--') 2594 .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}") 2595 ) 2596 2597 if sub_query.lstrip().lower().startswith('with '): 2598 depth = 0 2599 select_index = -1 2600 sq_lower = sub_query.lower() 2601 2602 # Iterate through the query to find the first 'SELECT' at the top level (depth 0) 2603 # Start searching after the 'WITH' keyword 2604 start_search = sq_lower.find('with') + 4 2605 2606 for i in range(start_search, len(sq_lower)): 2607 char = sq_lower[i] 2608 if char == '(': 2609 depth += 1 2610 elif char == ')': 2611 depth -= 1 2612 elif depth == 0: 2613 # Check for 'SELECT' at a word boundary 2614 if sq_lower[i:i+6] == 'select': 2615 # Ensure it's not part of another word (e.g., 'selection') 2616 # by checking the character immediately following 'select' 2617 is_bound = (i + 6 == len(sq_lower)) or (not sq_lower[i+6].isalnum()) 2618 if is_bound: 2619 select_index = i 2620 break 2621 2622 # If we found the main SELECT, we slice and flatten. 2623 # Part 1 (definitions) contains the 'WITH cte AS (...),' 2624 # Part 2 (body) contains the 'SELECT ... UNION ALL ...' 2625 if select_index != -1: 2626 definitions = sub_query[:select_index].rstrip() 2627 # If the definitions end in a comma (rare but possible), remove it 2628 if definitions.endswith(','): 2629 definitions = definitions[:-1].rstrip() 2630 2631 body = sub_query[select_index:].strip() 2632 2633 return ( 2634 f"{definitions},\n" 2635 f"{cte_name_quoted} AS (\n" 2636 f"{textwrap.indent(body, ' ')}\n" 2637 f")\n{parent_query}" 2638 ) 2639 2640 return ( 2641 f"WITH {cte_name_quoted} AS (\n" 2642 f"{textwrap.indent(sub_query, ' ')}\n" 2643 f")\n{parent_query}" 2644 ) 2645 2646 2647def format_cte_subquery( 2648 sub_query: str, 2649 flavor: str, 2650 sub_name: str = 'src', 2651 cols_to_select: Union[List[str], str] = '*', 2652) -> str: 2653 """ 2654 Given a subquery, build a wrapper query that selects from the CTE subquery. 2655 2656 Parameters 2657 ---------- 2658 sub_query: str 2659 The subquery to wrap. 2660 2661 flavor: str 2662 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2663 2664 sub_name: str, default 'src' 2665 If possible, give this name to the CTE (must be unquoted). 2666 2667 cols_to_select: Union[List[str], str], default '' 2668 If specified, choose which columns to select from the CTE. 2669 If a list of strings is provided, each item will be quoted and joined with commas. 2670 If a string is given, assume it is quoted and insert it into the query. 2671 2672 Returns 2673 ------- 2674 A wrapper query that selects from the CTE. 2675 """ 2676 quoted_sub_name = sql_item_name(sub_name, flavor, None) 2677 cols_str = ( 2678 cols_to_select 2679 if isinstance(cols_to_select, str) 2680 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 2681 ) 2682 parent_query = ( 2683 f"SELECT {cols_str}\n" 2684 f"FROM {quoted_sub_name}" 2685 ) 2686 return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name) 2687 2688 2689def session_execute( 2690 session: 'sqlalchemy.orm.session.Session', 2691 queries: Union[List[str], str], 2692 with_results: bool = False, 2693 debug: bool = False, 2694) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 2695 """ 2696 Similar to `SQLConnector.exec_queries()`, execute a list of queries 2697 and roll back when one fails. 2698 2699 Parameters 2700 ---------- 2701 session: sqlalchemy.orm.session.Session 2702 A SQLAlchemy session representing a transaction. 2703 2704 queries: Union[List[str], str] 2705 A query or list of queries to be executed. 2706 If a query fails, roll back the session. 2707 2708 with_results: bool, default False 2709 If `True`, return a list of result objects. 2710 2711 Returns 2712 ------- 2713 A `SuccessTuple` indicating the queries were successfully executed. 2714 If `with_results`, return the `SuccessTuple` and a list of results. 2715 """ 2716 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2717 if not isinstance(queries, list): 2718 queries = [queries] 2719 successes, msgs, results = [], [], [] 2720 for query in queries: 2721 if debug: 2722 dprint(query) 2723 query_text = sqlalchemy.text(query) 2724 fail_msg = "Failed to execute queries." 2725 try: 2726 result = session.execute(query_text) 2727 query_success = result is not None 2728 query_msg = "Success" if query_success else fail_msg 2729 except Exception as e: 2730 query_success = False 2731 query_msg = f"{fail_msg}\n{e}" 2732 result = None 2733 successes.append(query_success) 2734 msgs.append(query_msg) 2735 results.append(result) 2736 if not query_success: 2737 if debug: 2738 dprint("Rolling back session.") 2739 session.rollback() 2740 break 2741 success, msg = all(successes), '\n'.join(msgs) 2742 if with_results: 2743 return (success, msg), results 2744 return success, msg 2745 2746 2747def get_reset_autoincrement_queries( 2748 table: str, 2749 column: str, 2750 connector: mrsm.connectors.SQLConnector, 2751 schema: Optional[str] = None, 2752 debug: bool = False, 2753) -> List[str]: 2754 """ 2755 Return a list of queries to reset a table's auto-increment counter to the next largest value. 2756 2757 Parameters 2758 ---------- 2759 table: str 2760 The name of the table on which the auto-incrementing column exists. 2761 2762 column: str 2763 The name of the auto-incrementing column. 2764 2765 connector: mrsm.connectors.SQLConnector 2766 The SQLConnector to the database on which the table exists. 2767 2768 schema: Optional[str], default None 2769 The schema of the table. Defaults to `connector.schema`. 2770 2771 Returns 2772 ------- 2773 A list of queries to be executed to reset the auto-incrementing column. 2774 """ 2775 if not table_exists(table, connector, schema=schema, debug=debug): 2776 return [] 2777 2778 schema = schema or connector.schema 2779 max_id_name = sql_item_name('max_id', connector.flavor) 2780 table_name = sql_item_name(table, connector.flavor, schema) 2781 table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema) 2782 column_name = sql_item_name(column, connector.flavor) 2783 max_id = connector.value( 2784 f""" 2785 SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name} 2786 FROM {table_name} 2787 """, 2788 debug=debug, 2789 ) 2790 if max_id is None: 2791 return [] 2792 2793 reset_queries = reset_autoincrement_queries.get( 2794 connector.flavor, 2795 reset_autoincrement_queries['default'] 2796 ) 2797 if not isinstance(reset_queries, list): 2798 reset_queries = [reset_queries] 2799 2800 return [ 2801 query.format( 2802 column=column, 2803 column_name=column_name, 2804 table=table, 2805 table_name=table_name, 2806 table_seq_name=table_seq_name, 2807 val=max_id, 2808 val_plus_1=(max_id + 1), 2809 ) 2810 for query in reset_queries 2811 ] 2812 2813 2814def get_postgis_geo_columns_types( 2815 connectable: Union[ 2816 'mrsm.connectors.sql.SQLConnector', 2817 'sqlalchemy.orm.session.Session', 2818 'sqlalchemy.engine.base.Engine' 2819 ], 2820 table: str, 2821 schema: Optional[str] = 'public', 2822 debug: bool = False, 2823) -> Dict[str, str]: 2824 """ 2825 Return a dictionary mapping PostGIS geometry column names to geometry types. 2826 2827 Parameters 2828 ---------- 2829 connectable: Union[SQLConnector, sqlalchemy.orm.session.Session, sqlalchemy.engine.base.Engine] 2830 The SQLConnector, Session, or Engine. 2831 2832 table: str 2833 The table's name. 2834 2835 schema: Optional[str], default 'public' 2836 The schema to use for the fully qualified table name. 2837 2838 Returns 2839 ------- 2840 A dictionary mapping column names to geometry types. 2841 """ 2842 from meerschaum.utils.dtypes import get_geometry_type_srid 2843 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2844 default_type, default_srid = get_geometry_type_srid() 2845 default_type = default_type.upper() 2846 2847 clean(table) 2848 clean(str(schema)) 2849 schema = schema or 'public' 2850 truncated_schema_name = truncate_item_name(schema, flavor='postgis') 2851 truncated_table_name = truncate_item_name(table, flavor='postgis') 2852 query = sqlalchemy.text( 2853 "SELECT \"f_geometry_column\" AS \"column\", 'GEOMETRY' AS \"func\", \"type\", geom.\"srid\"\n" 2854 "FROM \"geometry_columns\" AS geom\n" 2855 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2856 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2857 "UNION ALL\n" 2858 "SELECT \"f_geography_column\" AS \"column\", 'GEOGRAPHY' AS \"func\", \"type\", geog.\"srid\"\n" 2859 "FROM \"geography_columns\" AS geog\n" 2860 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2861 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2862 ) 2863 debug_kwargs = {'debug': debug} if isinstance(connectable, mrsm.connectors.SQLConnector) else {} 2864 result_rows = [ 2865 row 2866 for row in connectable.execute(query, **debug_kwargs).fetchall() 2867 ] 2868 cols_type_tuples = { 2869 row[0]: ( 2870 row[1], # func 2871 row[2], # type 2872 ( 2873 row[3] 2874 if row[3] 2875 else 0 2876 ) # srid 2877 ) 2878 for row in result_rows 2879 } 2880 2881 geometry_cols_types = { 2882 col: ( 2883 f"{func}({typ.upper()}, {srid})" 2884 if srid 2885 else ( 2886 func 2887 + ( 2888 f'({typ.upper()})' 2889 if typ.upper() not in ('GEOMETRY', 'GEOGRAPHY') 2890 else '' 2891 ) 2892 ) 2893 ) 2894 for col, (func, typ, srid) in cols_type_tuples.items() 2895 } 2896 return geometry_cols_types 2897 2898 2899def get_create_schema_if_not_exists_queries( 2900 schema: str, 2901 flavor: str, 2902) -> List[str]: 2903 """ 2904 Return the queries to create a schema if it does not yet exist. 2905 For databases which do not support schemas, an empty list will be returned. 2906 """ 2907 if not schema: 2908 return [] 2909 2910 if flavor in NO_SCHEMA_FLAVORS: 2911 return [] 2912 2913 if schema == DEFAULT_SCHEMA_FLAVORS.get(flavor, None): 2914 return [] 2915 2916 clean(schema) 2917 2918 if flavor == 'mssql': 2919 return [ 2920 ( 2921 f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema}')\n" 2922 "BEGIN\n" 2923 f" EXEC('CREATE SCHEMA {sql_item_name(schema, flavor)}');\n" 2924 "END;" 2925 ) 2926 ] 2927 2928 if flavor == 'oracle': 2929 return [] 2930 2931 return [ 2932 f"CREATE SCHEMA IF NOT EXISTS {sql_item_name(schema, flavor)};" 2933 ]
669def clean(substring: str) -> None: 670 """ 671 Ensure a substring is clean enough to be inserted into a SQL query. 672 Raises an exception when banned words are used. 673 """ 674 from meerschaum.utils.warnings import error 675 banned_symbols = [ 676 ';', '--', 'drop ', '/*', '*/', 677 'union ', 'exec ', 'execute ', 678 'insert ', 'update ', 'delete ', 679 'create ', 'alter ', 'truncate ', 680 ] 681 for symbol in banned_symbols: 682 if symbol in str(substring).lower(): 683 error(f"Invalid string: '{substring}'")
Ensure a substring is clean enough to be inserted into a SQL query. Raises an exception when banned words are used.
689def dateadd_str( 690 flavor: str = 'postgresql', 691 datepart: str = 'day', 692 number: Union[int, float] = 0, 693 begin: Union[str, datetime, int] = 'now', 694 db_type: Optional[str] = None, 695) -> str: 696 """ 697 Generate a `DATEADD` clause depending on database flavor. 698 699 Parameters 700 ---------- 701 flavor: str, default `'postgresql'` 702 SQL database flavor, e.g. `'postgresql'`, `'sqlite'`. 703 704 Currently supported flavors: 705 706 - `'postgresql'` 707 - `'postgis'` 708 - `'timescaledb'` 709 - `'timescaledb-ha'` 710 - `'citus'` 711 - `'cockroachdb'` 712 - `'duckdb'` 713 - `'mssql'` 714 - `'mysql'` 715 - `'mariadb'` 716 - `'sqlite'` 717 - `'geopackage'` 718 - `'oracle'` 719 720 datepart: str, default `'day'` 721 Which part of the date to modify. Supported values: 722 723 - `'year'` 724 - `'month'` 725 - `'day'` 726 - `'hour'` 727 - `'minute'` 728 - `'second'` 729 730 number: Union[int, float], default `0` 731 How many units to add to the date part. 732 733 begin: Union[str, datetime], default `'now'` 734 Base datetime to which to add dateparts. 735 736 db_type: Optional[str], default None 737 If provided, cast the datetime string as the type. 738 Otherwise, infer this from the input datetime value. 739 740 Returns 741 ------- 742 The appropriate `DATEADD` string for the corresponding database flavor. 743 744 Examples 745 -------- 746 >>> dateadd_str( 747 ... flavor='mssql', 748 ... begin=datetime(2022, 1, 1, 0, 0), 749 ... number=1, 750 ... ) 751 "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))" 752 >>> dateadd_str( 753 ... flavor='postgresql', 754 ... begin=datetime(2022, 1, 1, 0, 0), 755 ... number=1, 756 ... ) 757 "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'" 758 759 """ 760 from meerschaum.utils.packages import attempt_import 761 from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type 762 dateutil_parser = attempt_import('dateutil.parser') 763 if 'int' in str(type(begin)).lower(): 764 num_str = str(begin) 765 if number is not None and number != 0: 766 num_str += ( 767 f' + {number}' 768 if number > 0 769 else f" - {number * -1}" 770 ) 771 return num_str 772 if datepart not in _VALID_DATEPARTS: 773 raise ValueError( 774 f"Invalid datepart '{datepart}'. Must be one of: {sorted(_VALID_DATEPARTS)}" 775 ) 776 if not begin: 777 return '' 778 779 _original_begin = begin 780 begin_time = None 781 ### Sanity check: make sure `begin` is a valid datetime before we inject anything. 782 if not isinstance(begin, datetime): 783 try: 784 begin_time = dateutil_parser.parse(begin) 785 except Exception: 786 begin_time = None 787 else: 788 begin_time = begin 789 790 ### Unable to parse into a datetime. 791 if begin_time is None: 792 ### Throw an error if banned symbols are included in the `begin` string. 793 clean(str(begin)) 794 ### If begin is a valid datetime, wrap it in quotes. 795 else: 796 if isinstance(begin, datetime) and begin.tzinfo is not None: 797 begin = begin.astimezone(timezone.utc) 798 begin = ( 799 f"'{begin.replace(tzinfo=None)}'" 800 if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS 801 else f"'{begin}'" 802 ) 803 804 dt_is_utc = ( 805 begin_time.tzinfo is not None 806 if begin_time is not None 807 else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1]) 808 ) 809 if db_type: 810 db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower() 811 dt_is_utc = dt_is_utc or db_type_is_utc 812 db_type = db_type or get_db_type_from_pd_type( 813 ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'), 814 flavor=flavor, 815 ) 816 817 da = "" 818 if flavor in ( 819 'postgresql', 820 'postgis', 821 'timescaledb', 822 'timescaledb-ha', 823 'cockroachdb', 824 'citus', 825 ): 826 begin = ( 827 f"CAST({begin} AS {db_type})" if begin != 'now' 828 else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})" 829 ) 830 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 831 832 elif flavor == 'duckdb': 833 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()' 834 da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '') 835 836 elif flavor in ('mssql',): 837 if begin_time and begin_time.microsecond != 0 and not dt_is_utc: 838 begin = begin[:-4] + "'" 839 begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()' 840 if dt_is_utc: 841 begin += " AT TIME ZONE 'UTC'" 842 da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin 843 844 elif flavor in ('mysql', 'mariadb'): 845 begin = ( 846 f"CAST({begin} AS DATETIME(6))" 847 if begin != 'now' 848 else 'UTC_TIMESTAMP(6)' 849 ) 850 da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin) 851 852 elif flavor in ('sqlite', 'geopackage'): 853 da = f"datetime({begin}, '{number} {datepart}')" 854 855 elif flavor == 'oracle': 856 if begin == 'now': 857 begin = str( 858 datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y:%m:%d %M:%S.%f') 859 ) 860 elif begin_time: 861 begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f')) 862 dt_format = 'YYYY-MM-DD HH24:MI:SS.FF' 863 _begin = f"'{begin}'" if begin_time else begin 864 da = ( 865 (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin) 866 + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "") 867 ) 868 return da
Generate a DATEADD clause depending on database flavor.
Parameters
flavor (str, default
'postgresql'): SQL database flavor, e.g.'postgresql','sqlite'.Currently supported flavors:
'postgresql''postgis''timescaledb''timescaledb-ha''citus''cockroachdb''duckdb''mssql''mysql''mariadb''sqlite''geopackage''oracle'
datepart (str, default
'day'): Which part of the date to modify. Supported values:'year''month''day''hour''minute''second'
- number (Union[int, float], default
0): How many units to add to the date part. - begin (Union[str, datetime], default
'now'): Base datetime to which to add dateparts. - db_type (Optional[str], default None): If provided, cast the datetime string as the type. Otherwise, infer this from the input datetime value.
Returns
- The appropriate
DATEADDstring for the corresponding database flavor.
Examples
>>> dateadd_str(
... flavor='mssql',
... begin=datetime(2022, 1, 1, 0, 0),
... number=1,
... )
"DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
>>> dateadd_str(
... flavor='postgresql',
... begin=datetime(2022, 1, 1, 0, 0),
... number=1,
... )
"CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"
871def test_connection( 872 self, 873 **kw: Any 874) -> Union[bool, None]: 875 """ 876 Test if a successful connection to the database may be made. 877 878 Parameters 879 ---------- 880 **kw: 881 The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`. 882 883 Returns 884 ------- 885 `True` if a connection is made, otherwise `False` or `None` in case of failure. 886 887 """ 888 import warnings 889 from meerschaum.connectors.poll import retry_connect 890 _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self} 891 _default_kw.update(kw) 892 with warnings.catch_warnings(): 893 warnings.filterwarnings('ignore', 'Could not') 894 try: 895 return retry_connect(**_default_kw) 896 except Exception: 897 return False
Test if a successful connection to the database may be made.
Parameters
- **kw:: The keyword arguments are passed to
meerschaum.connectors.poll.retry_connect.
Returns
Trueif a connection is made, otherwiseFalseorNonein case of failure.
900def get_distinct_col_count( 901 col: str, 902 query: str, 903 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 904 debug: bool = False 905) -> Optional[int]: 906 """ 907 Returns the number of distinct items in a column of a SQL query. 908 909 Parameters 910 ---------- 911 col: str: 912 The column in the query to count. 913 914 query: str: 915 The SQL query to count from. 916 917 connector: Optional[mrsm.connectors.sql.SQLConnector], default None: 918 The SQLConnector to execute the query. 919 920 debug: bool, default False: 921 Verbosity toggle. 922 923 Returns 924 ------- 925 An `int` of the number of columns in the query or `None` if the query fails. 926 927 """ 928 if connector is None: 929 connector = mrsm.get_connector('sql') 930 931 _col_name = sql_item_name(col, connector.flavor, None) 932 933 _meta_query = ( 934 f""" 935 WITH src AS ( {query} ), 936 dist AS ( SELECT DISTINCT {_col_name} FROM src ) 937 SELECT COUNT(*) FROM dist""" 938 ) if connector.flavor not in ('mysql', 'mariadb') else ( 939 f""" 940 SELECT COUNT(*) 941 FROM ( 942 SELECT DISTINCT {_col_name} 943 FROM ({query}) AS src 944 ) AS dist""" 945 ) 946 947 result = connector.value(_meta_query, debug=debug) 948 try: 949 return int(result) 950 except Exception: 951 return None
Returns the number of distinct items in a column of a SQL query.
Parameters
- col (str:): The column in the query to count.
- query (str:): The SQL query to count from.
- connector (Optional[mrsm.connectors.sql.SQLConnector], default None:): The SQLConnector to execute the query.
- debug (bool, default False:): Verbosity toggle.
Returns
- An
intof the number of columns in the query orNoneif the query fails.
954def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str: 955 """ 956 Parse SQL items depending on the flavor. 957 958 Parameters 959 ---------- 960 item: str 961 The database item (table, view, etc.) in need of quotes. 962 963 flavor: str 964 The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.). 965 966 schema: Optional[str], default None 967 If provided, prefix the table name with the schema. 968 969 Returns 970 ------- 971 A `str` which contains the input `item` wrapped in the corresponding escape characters. 972 973 Examples 974 -------- 975 >>> sql_item_name('table', 'sqlite') 976 '"table"' 977 >>> sql_item_name('table', 'mssql') 978 "[table]" 979 >>> sql_item_name('table', 'postgresql', schema='abc') 980 '"abc"."table"' 981 982 """ 983 truncated_item = truncate_item_name(str(item), flavor) 984 if flavor == 'oracle': 985 truncated_item = pg_capital(truncated_item, quote_capitals=True) 986 ### NOTE: System-reserved words must be quoted. 987 if truncated_item.lower() in ( 988 'float', 'varchar', 'nvarchar', 'clob', 989 'boolean', 'integer', 'table', 'row', 'date', 'time', 990 ): 991 wrappers = ('"', '"') 992 else: 993 wrappers = ('', '') 994 else: 995 wrappers = table_wrappers.get(flavor, table_wrappers['default']) 996 997 ### NOTE: SQLite does not support schemas. 998 if flavor in ('sqlite', 'geopackage'): 999 schema = None 1000 elif flavor == 'mssql' and str(item).startswith('#'): 1001 schema = None 1002 1003 schema_prefix = ( 1004 (wrappers[0] + schema + wrappers[1] + '.') 1005 if schema is not None 1006 else '' 1007 ) 1008 1009 return schema_prefix + wrappers[0] + truncated_item + wrappers[1]
Parse SQL items depending on the flavor.
Parameters
- item (str): The database item (table, view, etc.) in need of quotes.
- flavor (str):
The database flavor (
'postgresql','mssql','sqllite', etc.). - schema (Optional[str], default None): If provided, prefix the table name with the schema.
Returns
- A
strwhich contains the inputitemwrapped in the corresponding escape characters.
Examples
>>> sql_item_name('table', 'sqlite')
'"table"'
>>> sql_item_name('table', 'mssql')
"[table]"
>>> sql_item_name('table', 'postgresql', schema='abc')
'"abc"."table"'
1012def pg_capital(s: str, quote_capitals: bool = True) -> str: 1013 """ 1014 If string contains a capital letter, wrap it in double quotes. 1015 1016 Parameters 1017 ---------- 1018 s: str 1019 The string to be escaped. 1020 1021 quote_capitals: bool, default True 1022 If `False`, do not quote strings with contain only a mix of capital and lower-case letters. 1023 1024 Returns 1025 ------- 1026 The input string wrapped in quotes only if it needs them. 1027 1028 Examples 1029 -------- 1030 >>> pg_capital("My Table") 1031 '"My Table"' 1032 >>> pg_capital('my_table') 1033 'my_table' 1034 1035 """ 1036 if s.startswith('"') and s.endswith('"'): 1037 return s 1038 1039 s = s.replace('"', '') 1040 1041 needs_quotes = s.startswith('_') 1042 if not needs_quotes: 1043 for c in s: 1044 if c == '_': 1045 continue 1046 1047 if not c.isalnum() or (quote_capitals and c.isupper()): 1048 needs_quotes = True 1049 break 1050 1051 if needs_quotes: 1052 return '"' + s + '"' 1053 1054 return s
If string contains a capital letter, wrap it in double quotes.
Parameters
- s (str): The string to be escaped.
- quote_capitals (bool, default True):
If
False, do not quote strings with contain only a mix of capital and lower-case letters.
Returns
- The input string wrapped in quotes only if it needs them.
Examples
>>> pg_capital("My Table")
'"My Table"'
>>> pg_capital('my_table')
'my_table'
1057def oracle_capital(s: str) -> str: 1058 """ 1059 Capitalize the string of an item on an Oracle database. 1060 """ 1061 return s
Capitalize the string of an item on an Oracle database.
1064def truncate_item_name(item: str, flavor: str) -> str: 1065 """ 1066 Truncate item names to stay within the database flavor's character limit. 1067 1068 Parameters 1069 ---------- 1070 item: str 1071 The database item being referenced. This string is the "canonical" name internally. 1072 1073 flavor: str 1074 The flavor of the database on which `item` resides. 1075 1076 Returns 1077 ------- 1078 The truncated string. 1079 """ 1080 from meerschaum.utils.misc import truncate_string_sections 1081 return truncate_string_sections( 1082 item, max_len=max_name_lens.get(flavor, max_name_lens['default']) 1083 )
Truncate item names to stay within the database flavor's character limit.
Parameters
- item (str): The database item being referenced. This string is the "canonical" name internally.
- flavor (str):
The flavor of the database on which
itemresides.
Returns
- The truncated string.
1086def build_where( 1087 params: Dict[str, Any], 1088 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1089 with_where: bool = True, 1090 flavor: str = 'postgresql', 1091) -> str: 1092 """ 1093 Build the `WHERE` clause based on the input criteria. 1094 1095 Parameters 1096 ---------- 1097 params: Dict[str, Any]: 1098 The keywords dictionary to convert into a WHERE clause. 1099 If a value is a string which begins with an underscore, negate that value 1100 (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`). 1101 A value of `_None` will be interpreted as `IS NOT NULL`. 1102 1103 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1104 The Meerschaum SQLConnector that will be executing the query. 1105 The connector is used to extract the SQL dialect. 1106 1107 with_where: bool, default True: 1108 If `True`, include the leading `'WHERE'` string. 1109 1110 flavor: str, default 'postgresql' 1111 If `connector` is `None`, fall back to this flavor. 1112 1113 Returns 1114 ------- 1115 A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor. 1116 1117 Examples 1118 -------- 1119 ``` 1120 >>> print(build_where({'foo': [1, 2, 3]})) 1121 1122 WHERE 1123 "foo" IN ('1', '2', '3') 1124 ``` 1125 """ 1126 import json 1127 from meerschaum._internal.static import STATIC_CONFIG 1128 from meerschaum.utils.warnings import warn 1129 from meerschaum.utils.dtypes import value_is_null, none_if_null 1130 negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix'] 1131 try: 1132 params_json = json.dumps(params) 1133 except Exception: 1134 params_json = str(params) 1135 bad_words = ['drop ', '--', ';'] 1136 for word in bad_words: 1137 if word in params_json.lower(): 1138 warn("Aborting build_where() due to possible SQL injection.") 1139 return '' 1140 1141 query_flavor = getattr(connector, 'flavor', flavor) if connector is not None else flavor 1142 where = "" 1143 leading_and = "\n AND " 1144 for key, value in params.items(): 1145 _key = sql_item_name(key, query_flavor, None) 1146 ### search across a list (i.e. IN syntax) 1147 if isinstance(value, Iterable) and not isinstance(value, (dict, str)): 1148 includes = [ 1149 none_if_null(item) 1150 for item in value 1151 if not str(item).startswith(negation_prefix) 1152 ] 1153 null_includes = [item for item in includes if item is None] 1154 not_null_includes = [item for item in includes if item is not None] 1155 excludes = [ 1156 none_if_null(str(item)[len(negation_prefix):]) 1157 for item in value 1158 if str(item).startswith(negation_prefix) 1159 ] 1160 null_excludes = [item for item in excludes if item is None] 1161 not_null_excludes = [item for item in excludes if item is not None] 1162 1163 if includes: 1164 where += f"{leading_and}(" 1165 if not_null_includes: 1166 where += f"{_key} IN (" 1167 for item in not_null_includes: 1168 quoted_item = str(item).replace("'", "''") 1169 where += f"'{quoted_item}', " 1170 where = where[:-2] + ")" 1171 if null_includes: 1172 where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL" 1173 if includes: 1174 where += ")" 1175 1176 if excludes: 1177 where += f"{leading_and}(" 1178 if not_null_excludes: 1179 where += f"{_key} NOT IN (" 1180 for item in not_null_excludes: 1181 quoted_item = str(item).replace("'", "''") 1182 where += f"'{quoted_item}', " 1183 where = where[:-2] + ")" 1184 if null_excludes: 1185 where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL" 1186 if excludes: 1187 where += ")" 1188 1189 continue 1190 1191 ### search a dictionary 1192 elif isinstance(value, dict): 1193 import json 1194 where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'") 1195 continue 1196 1197 eq_sign = '=' 1198 is_null = 'IS NULL' 1199 if value_is_null(str(value).lstrip(negation_prefix)): 1200 value = ( 1201 (negation_prefix + 'None') 1202 if str(value).startswith(negation_prefix) 1203 else None 1204 ) 1205 if str(value).startswith(negation_prefix): 1206 value = str(value)[len(negation_prefix):] 1207 eq_sign = '!=' 1208 if value_is_null(value): 1209 value = None 1210 is_null = 'IS NOT NULL' 1211 quoted_value = str(value).replace("'", "''") 1212 where += ( 1213 f"{leading_and}{_key} " 1214 + (is_null if value is None else f"{eq_sign} '{quoted_value}'") 1215 ) 1216 1217 if len(where) > 1: 1218 where = ("\nWHERE\n " if with_where else '') + where[len(leading_and):] 1219 return where
Build the WHERE clause based on the input criteria.
Parameters
- params (Dict[str, Any]:):
The keywords dictionary to convert into a WHERE clause.
If a value is a string which begins with an underscore, negate that value
(e.g.
!=instead of=orNOT INinstead ofIN). A value of_Nonewill be interpreted asIS NOT NULL. - connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The Meerschaum SQLConnector that will be executing the query. The connector is used to extract the SQL dialect.
- with_where (bool, default True:):
If
True, include the leading'WHERE'string. - flavor (str, default 'postgresql'):
If
connectorisNone, fall back to this flavor.
Returns
- A
strof theWHEREclause from the inputparamsdictionary for the connector's flavor.
Examples
>>> print(build_where({'foo': [1, 2, 3]}))
WHERE
"foo" IN ('1', '2', '3')
1222def table_exists( 1223 table: str, 1224 connector: mrsm.connectors.sql.SQLConnector, 1225 schema: Optional[str] = None, 1226 debug: bool = False, 1227) -> bool: 1228 """Check if a table exists. 1229 1230 Parameters 1231 ---------- 1232 table: str: 1233 The name of the table in question. 1234 1235 connector: mrsm.connectors.sql.SQLConnector 1236 The connector to the database which holds the table. 1237 1238 schema: Optional[str], default None 1239 Optionally specify the table schema. 1240 Defaults to `connector.schema`. 1241 1242 debug: bool, default False : 1243 Verbosity toggle. 1244 1245 Returns 1246 ------- 1247 A `bool` indicating whether or not the table exists on the database. 1248 """ 1249 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1250 schema = schema or connector.schema 1251 insp = sqlalchemy.inspect(connector.engine) 1252 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1253 return insp.has_table(truncated_table_name, schema=schema)
Check if a table exists.
Parameters
- table (str:): The name of the table in question.
- connector (mrsm.connectors.sql.SQLConnector): The connector to the database which holds the table.
- schema (Optional[str], default None):
Optionally specify the table schema.
Defaults to
connector.schema. - debug (bool, default False :): Verbosity toggle.
Returns
- A
boolindicating whether or not the table exists on the database.
1256def get_sqlalchemy_table( 1257 table: str, 1258 connector: Optional[mrsm.connectors.sql.SQLConnector] = None, 1259 schema: Optional[str] = None, 1260 refresh: bool = False, 1261 debug: bool = False, 1262) -> Union['sqlalchemy.Table', None]: 1263 """ 1264 Construct a SQLAlchemy table from its name. 1265 1266 Parameters 1267 ---------- 1268 table: str 1269 The name of the table on the database. Does not need to be escaped. 1270 1271 connector: Optional[meerschaum.connectors.sql.SQLConnector], default None: 1272 The connector to the database which holds the table. 1273 1274 schema: Optional[str], default None 1275 Specify on which schema the table resides. 1276 Defaults to the schema set in `connector`. 1277 1278 refresh: bool, default False 1279 If `True`, rebuild the cached table object. 1280 1281 debug: bool, default False: 1282 Verbosity toggle. 1283 1284 Returns 1285 ------- 1286 A `sqlalchemy.Table` object for the table. 1287 1288 """ 1289 if connector is None: 1290 from meerschaum import get_connector 1291 connector = get_connector('sql') 1292 1293 if connector.flavor == 'duckdb': 1294 return None 1295 1296 from meerschaum.connectors.sql.tables import get_tables 1297 from meerschaum.utils.packages import attempt_import 1298 from meerschaum.utils.warnings import warn 1299 if refresh: 1300 connector.metadata.clear() 1301 tables = get_tables(mrsm_instance=connector, debug=debug, create=False) 1302 sqlalchemy = attempt_import('sqlalchemy', lazy=False) 1303 truncated_table_name = truncate_item_name(str(table), connector.flavor) 1304 table_kwargs = { 1305 'autoload_with': connector.engine, 1306 } 1307 if schema: 1308 table_kwargs['schema'] = schema 1309 1310 if refresh or truncated_table_name not in tables: 1311 try: 1312 tables[truncated_table_name] = sqlalchemy.Table( 1313 truncated_table_name, 1314 connector.metadata, 1315 **table_kwargs 1316 ) 1317 except sqlalchemy.exc.NoSuchTableError: 1318 warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.") 1319 return None 1320 return tables[truncated_table_name]
Construct a SQLAlchemy table from its name.
Parameters
- table (str): The name of the table on the database. Does not need to be escaped.
- connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The connector to the database which holds the table.
- schema (Optional[str], default None):
Specify on which schema the table resides.
Defaults to the schema set in
connector. - refresh (bool, default False):
If
True, rebuild the cached table object. - debug (bool, default False:): Verbosity toggle.
Returns
- A
sqlalchemy.Tableobject for the table.
1323def get_table_cols_types( 1324 table: str, 1325 connectable: Union[ 1326 'mrsm.connectors.sql.SQLConnector', 1327 'sqlalchemy.orm.session.Session', 1328 'sqlalchemy.engine.base.Engine' 1329 ], 1330 flavor: Optional[str] = None, 1331 schema: Optional[str] = None, 1332 database: Optional[str] = None, 1333 debug: bool = False, 1334) -> Dict[str, str]: 1335 """ 1336 Return a dictionary mapping a table's columns to data types. 1337 This is useful for inspecting tables creating during a not-yet-committed session. 1338 1339 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1340 Use this function if you are confident the table name is unique or if you have 1341 and explicit schema. 1342 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1343 1344 Parameters 1345 ---------- 1346 table: str 1347 The name of the table (unquoted). 1348 1349 connectable: Union[ 1350 'mrsm.connectors.sql.SQLConnector', 1351 'sqlalchemy.orm.session.Session', 1352 'sqlalchemy.engine.base.Engine' 1353 ] 1354 The connection object used to fetch the columns and types. 1355 1356 flavor: Optional[str], default None 1357 The database dialect flavor to use for the query. 1358 If omitted, default to `connectable.flavor`. 1359 1360 schema: Optional[str], default None 1361 If provided, restrict the query to this schema. 1362 1363 database: Optional[str]. default None 1364 If provided, restrict the query to this database. 1365 1366 Returns 1367 ------- 1368 A dictionary mapping column names to data types. 1369 """ 1370 import textwrap 1371 from meerschaum.connectors import SQLConnector 1372 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1373 flavor = flavor or getattr(connectable, 'flavor', None) 1374 if not flavor: 1375 raise ValueError("Please provide a database flavor.") 1376 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1377 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1378 if flavor in NO_SCHEMA_FLAVORS: 1379 schema = None 1380 if schema is None: 1381 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1382 if flavor in ('sqlite', 'duckdb', 'oracle', 'geopackage'): 1383 database = None 1384 table_trunc = truncate_item_name(table, flavor=flavor) 1385 table_lower = table.lower() 1386 table_upper = table.upper() 1387 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1388 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1389 db_prefix = ( 1390 "tempdb." 1391 if flavor == 'mssql' and table.startswith('#') 1392 else "" 1393 ) 1394 1395 def _esc(s: str) -> str: 1396 return s.replace("'", "''") 1397 1398 cols_types_query = sqlalchemy.text( 1399 textwrap.dedent(columns_types_queries.get( 1400 flavor, 1401 columns_types_queries['default'] 1402 ).format( 1403 table=_esc(table), 1404 table_trunc=_esc(table_trunc), 1405 table_lower=_esc(table_lower), 1406 table_lower_trunc=_esc(table_lower_trunc), 1407 table_upper=_esc(table_upper), 1408 table_upper_trunc=_esc(table_upper_trunc), 1409 db_prefix=db_prefix, 1410 )).lstrip().rstrip() 1411 ) 1412 1413 cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale'] 1414 result_cols_ix = dict(enumerate(cols)) 1415 1416 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1417 if not debug_kwargs and debug: 1418 dprint(cols_types_query) 1419 1420 try: 1421 result_rows = ( 1422 [ 1423 row 1424 for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall() 1425 ] 1426 if flavor != 'duckdb' 1427 else [ 1428 tuple([doc[col] for col in cols]) 1429 for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records') 1430 ] 1431 ) 1432 cols_types_docs = [ 1433 { 1434 result_cols_ix[i]: val 1435 for i, val in enumerate(row) 1436 } 1437 for row in result_rows 1438 ] 1439 cols_types_docs_filtered = [ 1440 doc 1441 for doc in cols_types_docs 1442 if ( 1443 ( 1444 not schema 1445 or doc['schema'] == schema 1446 ) 1447 and 1448 ( 1449 not database 1450 or doc['database'] == database 1451 ) 1452 ) 1453 ] 1454 1455 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1456 if cols_types_docs and not cols_types_docs_filtered: 1457 cols_types_docs_filtered = cols_types_docs 1458 1459 ### NOTE: Check for PostGIS GEOMETRY columns. 1460 geometry_cols_types = {} 1461 user_defined_cols = [ 1462 doc 1463 for doc in cols_types_docs_filtered 1464 if str(doc.get('type', None)).upper() == 'USER-DEFINED' 1465 ] 1466 if user_defined_cols: 1467 geometry_cols_types.update( 1468 get_postgis_geo_columns_types( 1469 connectable, 1470 table, 1471 schema=schema, 1472 debug=debug, 1473 ) 1474 ) 1475 1476 cols_types = { 1477 ( 1478 doc['column'] 1479 if flavor != 'oracle' else ( 1480 ( 1481 doc['column'].lower() 1482 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1483 else doc['column'] 1484 ) 1485 ) 1486 ): doc['type'].upper() + ( 1487 f'({precision},{scale})' 1488 if ( 1489 (precision := doc.get('numeric_precision', None)) 1490 and 1491 (scale := doc.get('numeric_scale', None)) 1492 ) 1493 else '' 1494 ) 1495 for doc in cols_types_docs_filtered 1496 } 1497 cols_types.update(geometry_cols_types) 1498 return cols_types 1499 except Exception as e: 1500 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1501 return {}
Return a dictionary mapping a table's columns to data types. This is useful for inspecting tables creating during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
and explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table() instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str]. default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to data types.
1504def get_table_cols_indices( 1505 table: str, 1506 connectable: Union[ 1507 'mrsm.connectors.sql.SQLConnector', 1508 'sqlalchemy.orm.session.Session', 1509 'sqlalchemy.engine.base.Engine' 1510 ], 1511 flavor: Optional[str] = None, 1512 schema: Optional[str] = None, 1513 database: Optional[str] = None, 1514 debug: bool = False, 1515) -> Dict[str, List[str]]: 1516 """ 1517 Return a dictionary mapping a table's columns to lists of indices. 1518 This is useful for inspecting tables creating during a not-yet-committed session. 1519 1520 NOTE: This may return incorrect columns if the schema is not explicitly stated. 1521 Use this function if you are confident the table name is unique or if you have 1522 and explicit schema. 1523 To use the configured schema, get the columns from `get_sqlalchemy_table()` instead. 1524 1525 Parameters 1526 ---------- 1527 table: str 1528 The name of the table (unquoted). 1529 1530 connectable: Union[ 1531 'mrsm.connectors.sql.SQLConnector', 1532 'sqlalchemy.orm.session.Session', 1533 'sqlalchemy.engine.base.Engine' 1534 ] 1535 The connection object used to fetch the columns and types. 1536 1537 flavor: Optional[str], default None 1538 The database dialect flavor to use for the query. 1539 If omitted, default to `connectable.flavor`. 1540 1541 schema: Optional[str], default None 1542 If provided, restrict the query to this schema. 1543 1544 database: Optional[str]. default None 1545 If provided, restrict the query to this database. 1546 1547 Returns 1548 ------- 1549 A dictionary mapping column names to a list of indices. 1550 """ 1551 import textwrap 1552 from collections import defaultdict 1553 from meerschaum.connectors import SQLConnector 1554 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 1555 flavor = flavor or getattr(connectable, 'flavor', None) 1556 if not flavor: 1557 raise ValueError("Please provide a database flavor.") 1558 if flavor == 'duckdb' and not isinstance(connectable, SQLConnector): 1559 raise ValueError("You must provide a SQLConnector when using DuckDB.") 1560 if flavor in NO_SCHEMA_FLAVORS: 1561 schema = None 1562 if schema is None: 1563 schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None) 1564 if flavor in ('sqlite', 'duckdb', 'oracle', 'geopackage'): 1565 database = None 1566 table_trunc = truncate_item_name(table, flavor=flavor) 1567 table_lower = table.lower() 1568 table_upper = table.upper() 1569 table_lower_trunc = truncate_item_name(table_lower, flavor=flavor) 1570 table_upper_trunc = truncate_item_name(table_upper, flavor=flavor) 1571 db_prefix = ( 1572 "tempdb." 1573 if flavor == 'mssql' and table.startswith('#') 1574 else "" 1575 ) 1576 1577 def _esc(s: str) -> str: 1578 return s.replace("'", "''") 1579 1580 cols_indices_query = sqlalchemy.text( 1581 textwrap.dedent(columns_indices_queries.get( 1582 flavor, 1583 columns_indices_queries['default'] 1584 ).format( 1585 table=_esc(table), 1586 table_trunc=_esc(table_trunc), 1587 table_lower=_esc(table_lower), 1588 table_lower_trunc=_esc(table_lower_trunc), 1589 table_upper=_esc(table_upper), 1590 table_upper_trunc=_esc(table_upper_trunc), 1591 db_prefix=db_prefix, 1592 schema=_esc(schema) if schema else schema, 1593 )).lstrip().rstrip() 1594 ) 1595 1596 cols = ['database', 'schema', 'table', 'column', 'index', 'index_type'] 1597 if flavor == 'mssql': 1598 cols.append('clustered') 1599 result_cols_ix = dict(enumerate(cols)) 1600 1601 debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {} 1602 if not debug_kwargs and debug: 1603 dprint(cols_indices_query) 1604 1605 try: 1606 result_rows = ( 1607 [ 1608 row 1609 for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall() 1610 ] 1611 if flavor != 'duckdb' 1612 else [ 1613 tuple([doc[col] for col in cols]) 1614 for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records') 1615 ] 1616 ) 1617 cols_types_docs = [ 1618 { 1619 result_cols_ix[i]: val 1620 for i, val in enumerate(row) 1621 } 1622 for row in result_rows 1623 ] 1624 cols_types_docs_filtered = [ 1625 doc 1626 for doc in cols_types_docs 1627 if ( 1628 ( 1629 not schema 1630 or doc['schema'] == schema 1631 ) 1632 and 1633 ( 1634 not database 1635 or doc['database'] == database 1636 ) 1637 ) 1638 ] 1639 ### NOTE: This may return incorrect columns if the schema is not explicitly stated. 1640 if cols_types_docs and not cols_types_docs_filtered: 1641 cols_types_docs_filtered = cols_types_docs 1642 1643 cols_indices = defaultdict(lambda: []) 1644 for doc in cols_types_docs_filtered: 1645 col = ( 1646 doc['column'] 1647 if flavor != 'oracle' 1648 else ( 1649 doc['column'].lower() 1650 if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha()) 1651 else doc['column'] 1652 ) 1653 ) 1654 index_doc = { 1655 'name': doc.get('index', None), 1656 'type': doc.get('index_type', None) 1657 } 1658 if flavor == 'mssql': 1659 index_doc['clustered'] = doc.get('clustered', None) 1660 cols_indices[col].append(index_doc) 1661 1662 return dict(cols_indices) 1663 except Exception as e: 1664 warn(f"Failed to fetch columns for table '{table}':\n{e}") 1665 return {}
Return a dictionary mapping a table's columns to lists of indices. This is useful for inspecting tables creating during a not-yet-committed session.
NOTE: This may return incorrect columns if the schema is not explicitly stated.
Use this function if you are confident the table name is unique or if you have
and explicit schema.
To use the configured schema, get the columns from get_sqlalchemy_table() instead.
Parameters
- table (str): The name of the table (unquoted).
- connectable (Union[): 'mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine'
- ]: The connection object used to fetch the columns and types.
- flavor (Optional[str], default None):
The database dialect flavor to use for the query.
If omitted, default to
connectable.flavor. - schema (Optional[str], default None): If provided, restrict the query to this schema.
- database (Optional[str]. default None): If provided, restrict the query to this database.
Returns
- A dictionary mapping column names to a list of indices.
1668def get_update_queries( 1669 target: str, 1670 patch: str, 1671 connectable: Union[ 1672 mrsm.connectors.sql.SQLConnector, 1673 'sqlalchemy.orm.session.Session' 1674 ], 1675 join_cols: Iterable[str], 1676 flavor: Optional[str] = None, 1677 upsert: bool = False, 1678 datetime_col: Optional[str] = None, 1679 schema: Optional[str] = None, 1680 patch_schema: Optional[str] = None, 1681 target_cols_types: Optional[Dict[str, str]] = None, 1682 patch_cols_types: Optional[Dict[str, str]] = None, 1683 identity_insert: bool = False, 1684 null_indices: bool = True, 1685 cast_columns: bool = True, 1686 debug: bool = False, 1687) -> List[str]: 1688 """ 1689 Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table. 1690 1691 Parameters 1692 ---------- 1693 target: str 1694 The name of the target table. 1695 1696 patch: str 1697 The name of the patch table. This should have the same shape as the target. 1698 1699 connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session] 1700 The `SQLConnector` or SQLAlchemy session which will later execute the queries. 1701 1702 join_cols: List[str] 1703 The columns to use to join the patch to the target. 1704 1705 flavor: Optional[str], default None 1706 If using a SQLAlchemy session, provide the expected database flavor. 1707 1708 upsert: bool, default False 1709 If `True`, return an upsert query rather than an update. 1710 1711 datetime_col: Optional[str], default None 1712 If provided, bound the join query using this column as the datetime index. 1713 This must be present on both tables. 1714 1715 schema: Optional[str], default None 1716 If provided, use this schema when quoting the target table. 1717 Defaults to `connector.schema`. 1718 1719 patch_schema: Optional[str], default None 1720 If provided, use this schema when quoting the patch table. 1721 Defaults to `schema`. 1722 1723 target_cols_types: Optional[Dict[str, Any]], default None 1724 If provided, use these as the columns-types dictionary for the target table. 1725 Default will infer from the database context. 1726 1727 patch_cols_types: Optional[Dict[str, Any]], default None 1728 If provided, use these as the columns-types dictionary for the target table. 1729 Default will infer from the database context. 1730 1731 identity_insert: bool, default False 1732 If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries. 1733 Only applies for MSSQL upserts. 1734 1735 null_indices: bool, default True 1736 If `False`, do not coalesce index columns before joining. 1737 1738 cast_columns: bool, default True 1739 If `False`, do not cast update columns to the target table types. 1740 1741 debug: bool, default False 1742 Verbosity toggle. 1743 1744 Returns 1745 ------- 1746 A list of query strings to perform the update operation. 1747 """ 1748 import textwrap 1749 from meerschaum.connectors import SQLConnector 1750 from meerschaum.utils.debug import dprint 1751 from meerschaum.utils.dtypes import are_dtypes_equal 1752 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type 1753 flavor = flavor or getattr(connectable, 'flavor', None) 1754 if not flavor: 1755 raise ValueError("Provide a flavor if using a SQLAlchemy session.") 1756 if ( 1757 flavor in ('sqlite', 'geopackage') 1758 and isinstance(connectable, SQLConnector) 1759 and connectable.db_version < '3.33.0' 1760 ): 1761 flavor = 'sqlite_delete_insert' 1762 flavor_key = (f'{flavor}-upsert' if upsert else flavor) 1763 base_queries = UPDATE_QUERIES.get( 1764 flavor_key, 1765 UPDATE_QUERIES['default'] 1766 ) 1767 if not isinstance(base_queries, list): 1768 base_queries = [base_queries] 1769 schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None) 1770 patch_schema = patch_schema or schema 1771 target_table_columns = get_table_cols_types( 1772 target, 1773 connectable, 1774 flavor=flavor, 1775 schema=schema, 1776 debug=debug, 1777 ) if not target_cols_types else target_cols_types 1778 patch_table_columns = get_table_cols_types( 1779 patch, 1780 connectable, 1781 flavor=flavor, 1782 schema=patch_schema, 1783 debug=debug, 1784 ) if not patch_cols_types else patch_cols_types 1785 1786 patch_cols_str = ', '.join( 1787 [ 1788 sql_item_name(col, flavor) 1789 for col in patch_table_columns 1790 ] 1791 ) 1792 patch_cols_prefixed_str = ', '.join( 1793 [ 1794 'p.' + sql_item_name(col, flavor) 1795 for col in patch_table_columns 1796 ] 1797 ) 1798 1799 join_cols_str = ', '.join( 1800 [ 1801 sql_item_name(col, flavor) 1802 for col in join_cols 1803 ] 1804 ) 1805 1806 value_cols = [] 1807 join_cols_types = [] 1808 if debug: 1809 dprint("target_table_columns:") 1810 mrsm.pprint(target_table_columns) 1811 for c_name, c_type in target_table_columns.items(): 1812 if c_name not in patch_table_columns: 1813 continue 1814 if flavor in DB_FLAVORS_CAST_DTYPES: 1815 c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type) 1816 ( 1817 join_cols_types 1818 if c_name in join_cols 1819 else value_cols 1820 ).append((c_name, c_type)) 1821 if debug: 1822 dprint(f"value_cols: {value_cols}") 1823 1824 if not join_cols_types: 1825 return [] 1826 if not value_cols and not upsert: 1827 return [] 1828 1829 coalesce_join_cols_str = ', '.join( 1830 [ 1831 ( 1832 ( 1833 'COALESCE(' 1834 + sql_item_name(c_name, flavor) 1835 + ', ' 1836 + get_null_replacement(c_type, flavor) 1837 + ')' 1838 ) 1839 if null_indices 1840 else sql_item_name(c_name, flavor) 1841 ) 1842 for c_name, c_type in join_cols_types 1843 ] 1844 ) 1845 1846 update_or_nothing = ('UPDATE' if value_cols else 'NOTHING') 1847 1848 def sets_subquery(l_prefix: str, r_prefix: str): 1849 if not value_cols: 1850 return '' 1851 1852 utc_value_cols = { 1853 c_name 1854 for c_name, c_type in value_cols 1855 if ('utc' in get_pd_type_from_db_type(c_type).lower()) 1856 } if flavor not in TIMEZONE_NAIVE_FLAVORS else set() 1857 1858 cast_func_cols = { 1859 c_name: ( 1860 ('', '', '') 1861 if not cast_columns or ( 1862 flavor == 'oracle' 1863 and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes') 1864 ) 1865 else ( 1866 ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + ( 1867 " AT TIME ZONE 'UTC'" 1868 if c_name in utc_value_cols 1869 else '' 1870 )) 1871 if flavor not in ('sqlite', 'geopackage') 1872 else ('', '', '') 1873 ) 1874 ) 1875 for c_name, c_type in value_cols 1876 } 1877 return 'SET ' + ',\n'.join([ 1878 ( 1879 l_prefix + sql_item_name(c_name, flavor, None) 1880 + ' = ' 1881 + cast_func_cols[c_name][0] 1882 + r_prefix + sql_item_name(c_name, flavor, None) 1883 + cast_func_cols[c_name][1] 1884 + cast_func_cols[c_name][2] 1885 ) for c_name, c_type in value_cols 1886 ]) 1887 1888 def and_subquery(l_prefix: str, r_prefix: str): 1889 return '\n AND\n '.join([ 1890 ( 1891 ( 1892 "COALESCE(" 1893 + l_prefix 1894 + sql_item_name(c_name, flavor, None) 1895 + ", " 1896 + get_null_replacement(c_type, flavor) 1897 + ")" 1898 + '\n =\n ' 1899 + "COALESCE(" 1900 + r_prefix 1901 + sql_item_name(c_name, flavor, None) 1902 + ", " 1903 + get_null_replacement(c_type, flavor) 1904 + ")" 1905 ) 1906 if null_indices 1907 else ( 1908 l_prefix 1909 + sql_item_name(c_name, flavor, None) 1910 + ' = ' 1911 + r_prefix 1912 + sql_item_name(c_name, flavor, None) 1913 ) 1914 ) for c_name, c_type in join_cols_types 1915 ]) 1916 1917 skip_query_val = "" 1918 target_table_name = sql_item_name(target, flavor, schema) 1919 patch_table_name = sql_item_name(patch, flavor, patch_schema) 1920 dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None 1921 date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]' 1922 min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]' 1923 max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]' 1924 date_bounds_subquery = ( 1925 f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table}) 1926 AND 1927 f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})""" 1928 if datetime_col 1929 else "1 = 1" 1930 ) 1931 with_temp_date_bounds = f"""WITH [date_bounds] AS ( 1932 SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name} 1933 FROM {patch_table_name} 1934 )""" if datetime_col else "" 1935 identity_insert_on = ( 1936 f"SET IDENTITY_INSERT {target_table_name} ON" 1937 if identity_insert 1938 else skip_query_val 1939 ) 1940 identity_insert_off = ( 1941 f"SET IDENTITY_INSERT {target_table_name} OFF" 1942 if identity_insert 1943 else skip_query_val 1944 ) 1945 1946 ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices. 1947 when_matched_update_sets_subquery_none = "" if not value_cols else ( 1948 "\n WHEN MATCHED THEN\n" 1949 f" UPDATE {sets_subquery('', 'p.')}" 1950 ) 1951 1952 cols_equal_values = '\n,'.join( 1953 [ 1954 f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})" 1955 for c_name, c_type in value_cols 1956 ] 1957 ) 1958 on_duplicate_key_update = ( 1959 "ON DUPLICATE KEY UPDATE" 1960 if value_cols 1961 else "" 1962 ) 1963 ignore = "IGNORE " if not value_cols else "" 1964 1965 formatted_queries = [ 1966 textwrap.dedent(base_query.format( 1967 sets_subquery_none=sets_subquery('', 'p.'), 1968 sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'), 1969 sets_subquery_f=sets_subquery('f.', 'p.'), 1970 and_subquery_f=and_subquery('p.', 'f.'), 1971 and_subquery_t=and_subquery('p.', 't.'), 1972 target_table_name=target_table_name, 1973 patch_table_name=patch_table_name, 1974 patch_cols_str=patch_cols_str, 1975 patch_cols_prefixed_str=patch_cols_prefixed_str, 1976 date_bounds_subquery=date_bounds_subquery, 1977 join_cols_str=join_cols_str, 1978 coalesce_join_cols_str=coalesce_join_cols_str, 1979 update_or_nothing=update_or_nothing, 1980 when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none, 1981 cols_equal_values=cols_equal_values, 1982 on_duplicate_key_update=on_duplicate_key_update, 1983 ignore=ignore, 1984 with_temp_date_bounds=with_temp_date_bounds, 1985 identity_insert_on=identity_insert_on, 1986 identity_insert_off=identity_insert_off, 1987 )).lstrip().rstrip() 1988 for base_query in base_queries 1989 ] 1990 1991 ### NOTE: Allow for skipping some queries. 1992 return [query for query in formatted_queries if query]
Build a list of MERGE, UPDATE, DELETE/INSERT queries to apply a patch to target table.
Parameters
- target (str): The name of the target table.
- patch (str): The name of the patch table. This should have the same shape as the target.
- connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]):
The
SQLConnectoror SQLAlchemy session which will later execute the queries. - join_cols (List[str]): The columns to use to join the patch to the target.
- flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
- upsert (bool, default False):
If
True, return an upsert query rather than an update. - datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
- schema (Optional[str], default None):
If provided, use this schema when quoting the target table.
Defaults to
connector.schema. - patch_schema (Optional[str], default None):
If provided, use this schema when quoting the patch table.
Defaults to
schema. - target_cols_types (Optional[Dict[str, Any]], default None): If provided, use these as the columns-types dictionary for the target table. Default will infer from the database context.
- patch_cols_types (Optional[Dict[str, Any]], default None): If provided, use these as the columns-types dictionary for the target table. Default will infer from the database context.
- identity_insert (bool, default False):
If
True, includeSET IDENTITY_INSERTqueries before and after the update queries. Only applies for MSSQL upserts. - null_indices (bool, default True):
If
False, do not coalesce index columns before joining. - cast_columns (bool, default True):
If
False, do not cast update columns to the target table types. - debug (bool, default False): Verbosity toggle.
Returns
- A list of query strings to perform the update operation.
1995def get_null_replacement(typ: str, flavor: str) -> str: 1996 """ 1997 Return a value that may temporarily be used in place of NULL for this type. 1998 1999 Parameters 2000 ---------- 2001 typ: str 2002 The typ to be converted to NULL. 2003 2004 flavor: str 2005 The database flavor for which this value will be used. 2006 2007 Returns 2008 ------- 2009 A value which may stand in place of NULL for this type. 2010 `'None'` is returned if a value cannot be determined. 2011 """ 2012 from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES 2013 if 'geometry' in typ.lower(): 2014 return "'010100000000008058346FCDC100008058346FCDC1'" 2015 if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'): 2016 return '-987654321' 2017 if 'bool' in typ.lower() or typ.lower() == 'bit': 2018 bool_typ = ( 2019 PD_TO_DB_DTYPES_FLAVORS 2020 .get('bool', {}) 2021 .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default']) 2022 ) 2023 if flavor in DB_FLAVORS_CAST_DTYPES: 2024 bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ) 2025 val_to_cast = ( 2026 -987654321 2027 if flavor in ('mysql', 'mariadb') 2028 else 0 2029 ) 2030 return f'CAST({val_to_cast} AS {bool_typ})' 2031 if 'time' in typ.lower() or 'date' in typ.lower(): 2032 db_type = typ if typ.isupper() else None 2033 return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type) 2034 if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',): 2035 return '-987654321.0' 2036 if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char': 2037 return "'-987654321'" 2038 if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'): 2039 return '00' 2040 if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'): 2041 magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5' 2042 if flavor == 'mssql': 2043 return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)" 2044 return f"'{magic_val}'" 2045 return ('n' if flavor == 'oracle' else '') + "'-987654321'"
Return a value that may temporarily be used in place of NULL for this type.
Parameters
- typ (str): The typ to be converted to NULL.
- flavor (str): The database flavor for which this value will be used.
Returns
- A value which may stand in place of NULL for this type.
'None'is returned if a value cannot be determined.
2048def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]: 2049 """ 2050 Fetch the database version if possible. 2051 """ 2052 version_name = sql_item_name('version', conn.flavor, None) 2053 version_query = version_queries.get( 2054 conn.flavor, 2055 version_queries['default'] 2056 ).format(version_name=version_name) 2057 return conn.value(version_query, debug=debug)
Fetch the database version if possible.
2060def get_rename_table_queries( 2061 old_table: str, 2062 new_table: str, 2063 flavor: str, 2064 schema: Optional[str] = None, 2065 new_schema: Optional[str] = None, 2066) -> List[str]: 2067 """ 2068 Return queries to alter a table's name. 2069 2070 Parameters 2071 ---------- 2072 old_table: str 2073 The unquoted name of the old table. 2074 2075 new_table: str 2076 The unquoted name of the new table. 2077 2078 flavor: str 2079 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2080 2081 schema: Optional[str], default None 2082 The schema on which the table resides. 2083 2084 new_schema: Optional[str], default None 2085 If provided, move the table to this schema. 2086 Defaults to `schema`. 2087 2088 Returns 2089 ------- 2090 A list of `ALTER TABLE` or equivalent queries for the database flavor. 2091 """ 2092 if new_schema is None: 2093 new_schema = schema 2094 2095 if flavor in ('sqlite', 'geopackage'): 2096 schema = None 2097 new_schema = None 2098 2099 old_table_name = sql_item_name(old_table, flavor, schema) 2100 new_table_name = sql_item_name(new_table, flavor, None) 2101 2102 if flavor == 'mssql': 2103 queries = [] 2104 if schema != new_schema: 2105 queries.append(f"ALTER SCHEMA {sql_item_name(new_schema, flavor)} TRANSFER {old_table_name}") 2106 schema = new_schema 2107 old_table_name = sql_item_name(old_table, flavor, schema) 2108 if old_table != new_table: 2109 queries.append(f"EXEC sp_rename '{old_table_name}', '{new_table}'") 2110 return queries 2111 2112 if flavor in ('mysql', 'mariadb'): 2113 new_table_qualified = sql_item_name(new_table, flavor, new_schema) 2114 return [f"RENAME TABLE {old_table_name} TO {new_table_qualified}"] 2115 2116 if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else "" 2117 if flavor == 'duckdb': 2118 tmp_table = '_tmp_rename_' + new_table 2119 tmp_table_name = sql_item_name(tmp_table, flavor, schema) 2120 return ( 2121 get_create_table_queries( 2122 f"SELECT * FROM {old_table_name}", 2123 tmp_table, 2124 'duckdb', 2125 schema, 2126 ) + get_create_table_queries( 2127 f"SELECT * FROM {tmp_table_name}", 2128 new_table, 2129 'duckdb', 2130 new_schema, 2131 ) + [ 2132 f"DROP TABLE {if_exists_str} {tmp_table_name}", 2133 f"DROP TABLE {if_exists_str} {old_table_name}", 2134 ] 2135 ) 2136 2137 queries = [] 2138 if schema != new_schema: 2139 if flavor in ('postgresql', 'postgis', 'timescaledb', 'timescaledb-ha', 'citus', 'cockroachdb'): 2140 queries.append(f"ALTER TABLE {old_table_name} SET SCHEMA {sql_item_name(new_schema, flavor)}") 2141 schema = new_schema 2142 old_table_name = sql_item_name(old_table, flavor, schema) 2143 2144 if old_table != new_table: 2145 queries.append(f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}") 2146 2147 return queries
Return queries to alter a table's name.
Parameters
- old_table (str): The unquoted name of the old table.
- new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql','postgresql'. - schema (Optional[str], default None): The schema on which the table resides.
- new_schema (Optional[str], default None):
If provided, move the table to this schema.
Defaults to
schema.
Returns
- A list of
ALTER TABLEor equivalent queries for the database flavor.
2150def get_create_table_query( 2151 query_or_dtypes: Union[str, Dict[str, str]], 2152 new_table: str, 2153 flavor: str, 2154 schema: Optional[str] = None, 2155) -> str: 2156 """ 2157 NOTE: This function is deprecated. Use `get_create_table_queries()` instead. 2158 2159 Return a query to create a new table from a `SELECT` query. 2160 2161 Parameters 2162 ---------- 2163 query: Union[str, Dict[str, str]] 2164 The select query to use for the creation of the table. 2165 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 2166 2167 new_table: str 2168 The unquoted name of the new table. 2169 2170 flavor: str 2171 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2172 2173 schema: Optional[str], default None 2174 The schema on which the table will reside. 2175 2176 Returns 2177 ------- 2178 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2179 """ 2180 return get_create_table_queries( 2181 query_or_dtypes, 2182 new_table, 2183 flavor, 2184 schema=schema, 2185 primary_key=None, 2186 )[0]
NOTE: This function is deprecated. Use get_create_table_queries() instead.
Return a query to create a new table from a SELECT query.
Parameters
- query (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLEquery from the givendtypescolumns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql','postgresql'). - schema (Optional[str], default None): The schema on which the table will reside.
Returns
- A
CREATE TABLE(orSELECT INTO) query for the database flavor.
2189def get_create_table_queries( 2190 query_or_dtypes: Union[str, Dict[str, str]], 2191 new_table: str, 2192 flavor: str, 2193 schema: Optional[str] = None, 2194 primary_key: Optional[str] = None, 2195 primary_key_db_type: Optional[str] = None, 2196 autoincrement: bool = False, 2197 datetime_column: Optional[str] = None, 2198 _parse_dtypes: bool = True, 2199) -> List[str]: 2200 """ 2201 Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary. 2202 2203 Parameters 2204 ---------- 2205 query_or_dtypes: Union[str, Dict[str, str]] 2206 The select query to use for the creation of the table. 2207 If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns. 2208 2209 new_table: str 2210 The unquoted name of the new table. 2211 2212 flavor: str 2213 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`). 2214 2215 schema: Optional[str], default None 2216 The schema on which the table will reside. 2217 2218 primary_key: Optional[str], default None 2219 If provided, designate this column as the primary key in the new table. 2220 2221 primary_key_db_type: Optional[str], default None 2222 If provided, alter the primary key to this type (to set NOT NULL constraint). 2223 2224 autoincrement: bool, default False 2225 If `True` and `primary_key` is provided, create the `primary_key` column 2226 as an auto-incrementing integer column. 2227 2228 datetime_column: Optional[str], default None 2229 If provided, include this column in the primary key. 2230 Applicable to TimescaleDB only. 2231 2232 _parse_dtypes: bool, default True 2233 If `True`, cast Pandas dtypes to SQL dtypes. 2234 Otherwise pass through the given value directly. 2235 2236 Returns 2237 ------- 2238 A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor. 2239 """ 2240 if not isinstance(query_or_dtypes, (str, dict)): 2241 raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.") 2242 2243 method = ( 2244 _get_create_table_query_from_cte 2245 if isinstance(query_or_dtypes, str) 2246 else _get_create_table_query_from_dtypes 2247 ) 2248 return method( 2249 query_or_dtypes, 2250 new_table, 2251 flavor, 2252 schema=schema, 2253 primary_key=primary_key, 2254 primary_key_db_type=primary_key_db_type, 2255 autoincrement=(autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS), 2256 datetime_column=datetime_column, 2257 _parse_dtypes=_parse_dtypes, 2258 )
Return a query to create a new table from a SELECT query or a dtypes dictionary.
Parameters
- query_or_dtypes (Union[str, Dict[str, str]]):
The select query to use for the creation of the table.
If a dictionary is provided, return a
CREATE TABLEquery from the givendtypescolumns. - new_table (str): The unquoted name of the new table.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql','postgresql'). - schema (Optional[str], default None): The schema on which the table will reside.
- primary_key (Optional[str], default None): If provided, designate this column as the primary key in the new table.
- primary_key_db_type (Optional[str], default None): If provided, alter the primary key to this type (to set NOT NULL constraint).
- autoincrement (bool, default False):
If
Trueandprimary_keyis provided, create theprimary_keycolumn as an auto-incrementing integer column. - datetime_column (Optional[str], default None): If provided, include this column in the primary key. Applicable to TimescaleDB only.
- _parse_dtypes (bool, default True):
If
True, cast Pandas dtypes to SQL dtypes. Otherwise pass through the given value directly.
Returns
- A
CREATE TABLE(orSELECT INTO) query for the database flavor.
2542def wrap_query_with_cte( 2543 sub_query: str, 2544 parent_query: str, 2545 flavor: str, 2546 cte_name: str = "src", 2547) -> str: 2548 """ 2549 Wrap a subquery in a CTE and append an encapsulating query. 2550 2551 Parameters 2552 ---------- 2553 sub_query: str 2554 The query to be referenced. This may itself contain CTEs. 2555 Unless `cte_name` is provided, this will be aliased as `src`. 2556 2557 parent_query: str 2558 The larger query to append which references the subquery. 2559 This must not contain CTEs. 2560 2561 flavor: str 2562 The database flavor, e.g. `'mssql'`. 2563 2564 cte_name: str, default 'src' 2565 The CTE alias, defaults to `src`. 2566 2567 Returns 2568 ------- 2569 An encapsulating query which allows you to treat `sub_query` as a temporary table. 2570 2571 Examples 2572 -------- 2573 2574 >>> from meerschaum.utils.sql import wrap_query_with_cte 2575 >>> sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo" 2576 >>> parent_query = "SELECT newval * 3 FROM src" 2577 >>> query = wrap_query_with_cte(sub_query, parent_query, 'mssql') 2578 >>> print(query) 2579 >>> # WITH foo AS (SELECT 1 AS val), 2580 >>> # [src] AS ( 2581 >>> # SELECT (val * 2) AS newval FROM foo 2582 >>> # ) 2583 >>> # SELECT newval * 3 FROM src 2584 2585 """ 2586 import textwrap 2587 sub_query = sub_query.lstrip() 2588 cte_name_quoted = sql_item_name(cte_name, flavor, None) 2589 2590 if flavor in NO_CTE_FLAVORS: 2591 return ( 2592 parent_query 2593 .replace(cte_name_quoted, '--MRSM_SUBQUERY--') 2594 .replace(cte_name, '--MRSM_SUBQUERY--') 2595 .replace('--MRSM_SUBQUERY--', f"(\n{sub_query}\n) AS {cte_name_quoted}") 2596 ) 2597 2598 if sub_query.lstrip().lower().startswith('with '): 2599 depth = 0 2600 select_index = -1 2601 sq_lower = sub_query.lower() 2602 2603 # Iterate through the query to find the first 'SELECT' at the top level (depth 0) 2604 # Start searching after the 'WITH' keyword 2605 start_search = sq_lower.find('with') + 4 2606 2607 for i in range(start_search, len(sq_lower)): 2608 char = sq_lower[i] 2609 if char == '(': 2610 depth += 1 2611 elif char == ')': 2612 depth -= 1 2613 elif depth == 0: 2614 # Check for 'SELECT' at a word boundary 2615 if sq_lower[i:i+6] == 'select': 2616 # Ensure it's not part of another word (e.g., 'selection') 2617 # by checking the character immediately following 'select' 2618 is_bound = (i + 6 == len(sq_lower)) or (not sq_lower[i+6].isalnum()) 2619 if is_bound: 2620 select_index = i 2621 break 2622 2623 # If we found the main SELECT, we slice and flatten. 2624 # Part 1 (definitions) contains the 'WITH cte AS (...),' 2625 # Part 2 (body) contains the 'SELECT ... UNION ALL ...' 2626 if select_index != -1: 2627 definitions = sub_query[:select_index].rstrip() 2628 # If the definitions end in a comma (rare but possible), remove it 2629 if definitions.endswith(','): 2630 definitions = definitions[:-1].rstrip() 2631 2632 body = sub_query[select_index:].strip() 2633 2634 return ( 2635 f"{definitions},\n" 2636 f"{cte_name_quoted} AS (\n" 2637 f"{textwrap.indent(body, ' ')}\n" 2638 f")\n{parent_query}" 2639 ) 2640 2641 return ( 2642 f"WITH {cte_name_quoted} AS (\n" 2643 f"{textwrap.indent(sub_query, ' ')}\n" 2644 f")\n{parent_query}" 2645 )
Wrap a subquery in a CTE and append an encapsulating query.
Parameters
- sub_query (str):
The query to be referenced. This may itself contain CTEs.
Unless
cte_nameis provided, this will be aliased assrc. - parent_query (str): The larger query to append which references the subquery. This must not contain CTEs.
- flavor (str):
The database flavor, e.g.
'mssql'. - cte_name (str, default 'src'):
The CTE alias, defaults to
src.
Returns
- An encapsulating query which allows you to treat
sub_queryas a temporary table.
Examples
>>> from meerschaum.utils.sql import wrap_query_with_cte
>>> sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
>>> parent_query = "SELECT newval * 3 FROM src"
>>> query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
>>> print(query)
>>> # WITH foo AS (SELECT 1 AS val),
>>> # [src] AS (
>>> # SELECT (val * 2) AS newval FROM foo
>>> # )
>>> # SELECT newval * 3 FROM src
2648def format_cte_subquery( 2649 sub_query: str, 2650 flavor: str, 2651 sub_name: str = 'src', 2652 cols_to_select: Union[List[str], str] = '*', 2653) -> str: 2654 """ 2655 Given a subquery, build a wrapper query that selects from the CTE subquery. 2656 2657 Parameters 2658 ---------- 2659 sub_query: str 2660 The subquery to wrap. 2661 2662 flavor: str 2663 The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`. 2664 2665 sub_name: str, default 'src' 2666 If possible, give this name to the CTE (must be unquoted). 2667 2668 cols_to_select: Union[List[str], str], default '' 2669 If specified, choose which columns to select from the CTE. 2670 If a list of strings is provided, each item will be quoted and joined with commas. 2671 If a string is given, assume it is quoted and insert it into the query. 2672 2673 Returns 2674 ------- 2675 A wrapper query that selects from the CTE. 2676 """ 2677 quoted_sub_name = sql_item_name(sub_name, flavor, None) 2678 cols_str = ( 2679 cols_to_select 2680 if isinstance(cols_to_select, str) 2681 else ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select]) 2682 ) 2683 parent_query = ( 2684 f"SELECT {cols_str}\n" 2685 f"FROM {quoted_sub_name}" 2686 ) 2687 return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
Given a subquery, build a wrapper query that selects from the CTE subquery.
Parameters
- sub_query (str): The subquery to wrap.
- flavor (str):
The database flavor to use for the query (e.g.
'mssql','postgresql'. - sub_name (str, default 'src'): If possible, give this name to the CTE (must be unquoted).
- cols_to_select (Union[List[str], str], default ''): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.
Returns
- A wrapper query that selects from the CTE.
2690def session_execute( 2691 session: 'sqlalchemy.orm.session.Session', 2692 queries: Union[List[str], str], 2693 with_results: bool = False, 2694 debug: bool = False, 2695) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]: 2696 """ 2697 Similar to `SQLConnector.exec_queries()`, execute a list of queries 2698 and roll back when one fails. 2699 2700 Parameters 2701 ---------- 2702 session: sqlalchemy.orm.session.Session 2703 A SQLAlchemy session representing a transaction. 2704 2705 queries: Union[List[str], str] 2706 A query or list of queries to be executed. 2707 If a query fails, roll back the session. 2708 2709 with_results: bool, default False 2710 If `True`, return a list of result objects. 2711 2712 Returns 2713 ------- 2714 A `SuccessTuple` indicating the queries were successfully executed. 2715 If `with_results`, return the `SuccessTuple` and a list of results. 2716 """ 2717 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2718 if not isinstance(queries, list): 2719 queries = [queries] 2720 successes, msgs, results = [], [], [] 2721 for query in queries: 2722 if debug: 2723 dprint(query) 2724 query_text = sqlalchemy.text(query) 2725 fail_msg = "Failed to execute queries." 2726 try: 2727 result = session.execute(query_text) 2728 query_success = result is not None 2729 query_msg = "Success" if query_success else fail_msg 2730 except Exception as e: 2731 query_success = False 2732 query_msg = f"{fail_msg}\n{e}" 2733 result = None 2734 successes.append(query_success) 2735 msgs.append(query_msg) 2736 results.append(result) 2737 if not query_success: 2738 if debug: 2739 dprint("Rolling back session.") 2740 session.rollback() 2741 break 2742 success, msg = all(successes), '\n'.join(msgs) 2743 if with_results: 2744 return (success, msg), results 2745 return success, msg
Similar to SQLConnector.exec_queries(), execute a list of queries
and roll back when one fails.
Parameters
- session (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
- queries (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
- with_results (bool, default False):
If
True, return a list of result objects.
Returns
- A
SuccessTupleindicating the queries were successfully executed. - If
with_results, return theSuccessTupleand a list of results.
2748def get_reset_autoincrement_queries( 2749 table: str, 2750 column: str, 2751 connector: mrsm.connectors.SQLConnector, 2752 schema: Optional[str] = None, 2753 debug: bool = False, 2754) -> List[str]: 2755 """ 2756 Return a list of queries to reset a table's auto-increment counter to the next largest value. 2757 2758 Parameters 2759 ---------- 2760 table: str 2761 The name of the table on which the auto-incrementing column exists. 2762 2763 column: str 2764 The name of the auto-incrementing column. 2765 2766 connector: mrsm.connectors.SQLConnector 2767 The SQLConnector to the database on which the table exists. 2768 2769 schema: Optional[str], default None 2770 The schema of the table. Defaults to `connector.schema`. 2771 2772 Returns 2773 ------- 2774 A list of queries to be executed to reset the auto-incrementing column. 2775 """ 2776 if not table_exists(table, connector, schema=schema, debug=debug): 2777 return [] 2778 2779 schema = schema or connector.schema 2780 max_id_name = sql_item_name('max_id', connector.flavor) 2781 table_name = sql_item_name(table, connector.flavor, schema) 2782 table_seq_name = sql_item_name(table + '_' + column + '_seq', connector.flavor, schema) 2783 column_name = sql_item_name(column, connector.flavor) 2784 max_id = connector.value( 2785 f""" 2786 SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name} 2787 FROM {table_name} 2788 """, 2789 debug=debug, 2790 ) 2791 if max_id is None: 2792 return [] 2793 2794 reset_queries = reset_autoincrement_queries.get( 2795 connector.flavor, 2796 reset_autoincrement_queries['default'] 2797 ) 2798 if not isinstance(reset_queries, list): 2799 reset_queries = [reset_queries] 2800 2801 return [ 2802 query.format( 2803 column=column, 2804 column_name=column_name, 2805 table=table, 2806 table_name=table_name, 2807 table_seq_name=table_seq_name, 2808 val=max_id, 2809 val_plus_1=(max_id + 1), 2810 ) 2811 for query in reset_queries 2812 ]
Return a list of queries to reset a table's auto-increment counter to the next largest value.
Parameters
- table (str): The name of the table on which the auto-incrementing column exists.
- column (str): The name of the auto-incrementing column.
- connector (mrsm.connectors.SQLConnector): The SQLConnector to the database on which the table exists.
- schema (Optional[str], default None):
The schema of the table. Defaults to
connector.schema.
Returns
- A list of queries to be executed to reset the auto-incrementing column.
2815def get_postgis_geo_columns_types( 2816 connectable: Union[ 2817 'mrsm.connectors.sql.SQLConnector', 2818 'sqlalchemy.orm.session.Session', 2819 'sqlalchemy.engine.base.Engine' 2820 ], 2821 table: str, 2822 schema: Optional[str] = 'public', 2823 debug: bool = False, 2824) -> Dict[str, str]: 2825 """ 2826 Return a dictionary mapping PostGIS geometry column names to geometry types. 2827 2828 Parameters 2829 ---------- 2830 connectable: Union[SQLConnector, sqlalchemy.orm.session.Session, sqlalchemy.engine.base.Engine] 2831 The SQLConnector, Session, or Engine. 2832 2833 table: str 2834 The table's name. 2835 2836 schema: Optional[str], default 'public' 2837 The schema to use for the fully qualified table name. 2838 2839 Returns 2840 ------- 2841 A dictionary mapping column names to geometry types. 2842 """ 2843 from meerschaum.utils.dtypes import get_geometry_type_srid 2844 sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False) 2845 default_type, default_srid = get_geometry_type_srid() 2846 default_type = default_type.upper() 2847 2848 clean(table) 2849 clean(str(schema)) 2850 schema = schema or 'public' 2851 truncated_schema_name = truncate_item_name(schema, flavor='postgis') 2852 truncated_table_name = truncate_item_name(table, flavor='postgis') 2853 query = sqlalchemy.text( 2854 "SELECT \"f_geometry_column\" AS \"column\", 'GEOMETRY' AS \"func\", \"type\", geom.\"srid\"\n" 2855 "FROM \"geometry_columns\" AS geom\n" 2856 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2857 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2858 "UNION ALL\n" 2859 "SELECT \"f_geography_column\" AS \"column\", 'GEOGRAPHY' AS \"func\", \"type\", geog.\"srid\"\n" 2860 "FROM \"geography_columns\" AS geog\n" 2861 f"WHERE \"f_table_schema\" = '{truncated_schema_name}'\n" 2862 f" AND \"f_table_name\" = '{truncated_table_name}'\n" 2863 ) 2864 debug_kwargs = {'debug': debug} if isinstance(connectable, mrsm.connectors.SQLConnector) else {} 2865 result_rows = [ 2866 row 2867 for row in connectable.execute(query, **debug_kwargs).fetchall() 2868 ] 2869 cols_type_tuples = { 2870 row[0]: ( 2871 row[1], # func 2872 row[2], # type 2873 ( 2874 row[3] 2875 if row[3] 2876 else 0 2877 ) # srid 2878 ) 2879 for row in result_rows 2880 } 2881 2882 geometry_cols_types = { 2883 col: ( 2884 f"{func}({typ.upper()}, {srid})" 2885 if srid 2886 else ( 2887 func 2888 + ( 2889 f'({typ.upper()})' 2890 if typ.upper() not in ('GEOMETRY', 'GEOGRAPHY') 2891 else '' 2892 ) 2893 ) 2894 ) 2895 for col, (func, typ, srid) in cols_type_tuples.items() 2896 } 2897 return geometry_cols_types
Return a dictionary mapping PostGIS geometry column names to geometry types.
Parameters
- connectable (Union[SQLConnector, sqlalchemy.orm.session.Session, sqlalchemy.engine.base.Engine]): The SQLConnector, Session, or Engine.
- table (str): The table's name.
- schema (Optional[str], default 'public'): The schema to use for the fully qualified table name.
Returns
- A dictionary mapping column names to geometry types.
2900def get_create_schema_if_not_exists_queries( 2901 schema: str, 2902 flavor: str, 2903) -> List[str]: 2904 """ 2905 Return the queries to create a schema if it does not yet exist. 2906 For databases which do not support schemas, an empty list will be returned. 2907 """ 2908 if not schema: 2909 return [] 2910 2911 if flavor in NO_SCHEMA_FLAVORS: 2912 return [] 2913 2914 if schema == DEFAULT_SCHEMA_FLAVORS.get(flavor, None): 2915 return [] 2916 2917 clean(schema) 2918 2919 if flavor == 'mssql': 2920 return [ 2921 ( 2922 f"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema}')\n" 2923 "BEGIN\n" 2924 f" EXEC('CREATE SCHEMA {sql_item_name(schema, flavor)}');\n" 2925 "END;" 2926 ) 2927 ] 2928 2929 if flavor == 'oracle': 2930 return [] 2931 2932 return [ 2933 f"CREATE SCHEMA IF NOT EXISTS {sql_item_name(schema, flavor)};" 2934 ]
Return the queries to create a schema if it does not yet exist. For databases which do not support schemas, an empty list will be returned.