meerschaum.utils.sql

Flavor-specific SQL tools.

   1#! /usr/bin/env python
   2# -*- coding: utf-8 -*-
   3# vim:fenc=utf-8
   4
   5"""
   6Flavor-specific SQL tools.
   7"""
   8
   9from __future__ import annotations
  10
  11from datetime import datetime, timezone, timedelta
  12import meerschaum as mrsm
  13from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple
  14### Preserve legacy imports.
  15from meerschaum.utils.dtypes.sql import (
  16    DB_TO_PD_DTYPES,
  17    PD_TO_DB_DTYPES_FLAVORS,
  18    get_pd_type_from_db_type as get_pd_type,
  19    get_db_type_from_pd_type as get_db_type,
  20    TIMEZONE_NAIVE_FLAVORS,
  21)
  22from meerschaum.utils.warnings import warn
  23from meerschaum.utils.debug import dprint
  24
### Queries to test whether a connection is alive, keyed by flavor.
### The `'default'` entry is used for flavors without a specific query.
test_queries = {
    'default'    : 'SELECT 1',
    'oracle'     : 'SELECT 1 FROM DUAL',
    'informix'   : 'SELECT COUNT(*) FROM systables',
    'hsqldb'     : 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS',
}
### Queries to test whether a table exists (an error implies it does not).
### `table_name` is the escaped name of the table.
### `table` is the unescaped name of the table.
exists_queries = {
    'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0",
}
### Queries to fetch the server's version string, keyed by flavor.
version_queries = {
    'default': "SELECT VERSION() AS {version_name}",
    'sqlite': "SELECT SQLITE_VERSION() AS {version_name}",
    'mssql': "SELECT @@version",
    'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1",
}
### Flavors for which the `IF EXISTS` clause is skipped.
SKIP_IF_EXISTS_FLAVORS = {'mssql', 'oracle'}
### Flavors which accept `DROP TABLE IF EXISTS`.
DROP_IF_EXISTS_FLAVORS = {
    'timescaledb', 'postgresql', 'citus', 'mssql', 'mysql', 'mariadb', 'sqlite',
}
### Flavors which accept `DROP INDEX IF EXISTS`.
DROP_INDEX_IF_EXISTS_FLAVORS = {
    'mssql', 'timescaledb', 'postgresql', 'sqlite', 'citus',
}
### Flavors for which auto-increment columns are skipped.
SKIP_AUTO_INCREMENT_FLAVORS = {'citus', 'duckdb'}
### Flavors whose unique indices coalesce NULL values.
COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'postgresql', 'citus'}
### Templates for update / upsert statements, keyed by flavor.
### Plain-flavor keys update matching rows only; `'{flavor}-upsert'` variants
### also insert rows that do not match. Placeholders such as
### `{target_table_name}`, `{patch_table_name}`, and `{sets_subquery_none}`
### are filled in by the caller before execution.
### NOTE(review): list values appear to be executed element-by-element as
### separate statements — confirm in the caller.
UPDATE_QUERIES = {
    'default': """
        UPDATE {target_table_name} AS f
        {sets_subquery_none}
        FROM {target_table_name} AS t
        INNER JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
            ON
                {and_subquery_t}
        WHERE
            {and_subquery_f}
            AND
            {date_bounds_subquery}
    """,
    'timescaledb-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'postgresql-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'citus-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'cockroachdb-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    'mysql': """
        UPDATE {target_table_name} AS f
        JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
        ON
            {and_subquery_f}
        {sets_subquery_f}
        WHERE
            {date_bounds_subquery}
    """,
    'mysql-upsert': """
        INSERT {ignore}INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        {on_duplicate_key_update}
            {cols_equal_values}
    """,
    'mariadb': """
        UPDATE {target_table_name} AS f
        JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p
        ON
            {and_subquery_f}
        {sets_subquery_f}
        WHERE
            {date_bounds_subquery}
    """,
    'mariadb-upsert': """
        INSERT {ignore}INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        {on_duplicate_key_update}
            {cols_equal_values}
    """,
    'mssql': """
        {with_temp_date_bounds}
        MERGE {target_table_name} f
            USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON
                {and_subquery_f}
            AND
                {date_bounds_subquery}
        WHEN MATCHED THEN
            UPDATE
            {sets_subquery_none};
    """,
    ### Three statements: toggle IDENTITY_INSERT around the MERGE.
    'mssql-upsert': [
        "{identity_insert_on}",
        """
        {with_temp_date_bounds}
        MERGE {target_table_name} f
            USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON
                {and_subquery_f}
            AND
                {date_bounds_subquery}{when_matched_update_sets_subquery_none}
        WHEN NOT MATCHED THEN
            INSERT ({patch_cols_str})
            VALUES ({patch_cols_prefixed_str});
        """,
        "{identity_insert_off}",
    ],
    'oracle': """
        MERGE INTO {target_table_name} f
            USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON (
                {and_subquery_f}
                AND
                {date_bounds_subquery}
            )
            WHEN MATCHED THEN
                UPDATE
                {sets_subquery_none}
    """,
    'oracle-upsert': """
        MERGE INTO {target_table_name} f
            USING (SELECT {patch_cols_str} FROM {patch_table_name}) p
            ON (
                {and_subquery_f}
                AND
                {date_bounds_subquery}
            ){when_matched_update_sets_subquery_none}
            WHEN NOT MATCHED THEN
                INSERT ({patch_cols_str})
                VALUES ({patch_cols_prefixed_str})
    """,
    'sqlite-upsert': """
        INSERT INTO {target_table_name} ({patch_cols_str})
        SELECT {patch_cols_str}
        FROM {patch_table_name}
        WHERE true
        ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}
    """,
    ### Two statements: delete matching rows, then insert the patch rows.
    'sqlite_delete_insert': [
        """
        DELETE FROM {target_table_name} AS f
        WHERE ROWID IN (
            SELECT t.ROWID
            FROM {target_table_name} AS t
            INNER JOIN (SELECT * FROM {patch_table_name}) AS p
                ON {and_subquery_t}
        );
        """,
        """
        INSERT INTO {target_table_name} AS f
        SELECT {patch_cols_str} FROM {patch_table_name} AS p
        """,
    ],
}
### Queries to fetch a table's column names and data types, keyed by flavor.
### Each returns `database`, `schema`, `table`, `column`, `type` columns
### (plus numeric precision/scale where the flavor exposes them).
### `{table}` is the unescaped name, `{table_trunc}` its length-truncated form;
### Oracle additionally checks upper- and lower-cased variants.
columns_types_queries = {
    'default': """
        SELECT
            table_catalog AS database,
            table_schema AS schema,
            table_name AS table,
            column_name AS column,
            data_type AS type,
            numeric_precision,
            numeric_scale
        FROM information_schema.columns
        WHERE table_name IN ('{table}', '{table_trunc}')
    """,
    'sqlite': """
        SELECT
            '' "database",
            '' "schema",
            m.name "table",
            p.name "column",
            p.type "type"
        FROM sqlite_master m
        LEFT OUTER JOIN pragma_table_info(m.name) p
            ON m.name <> p.name
        WHERE m.type = 'table'
            AND m.name IN ('{table}', '{table_trunc}')
    """,
    'mssql': """
        SELECT
            TABLE_CATALOG AS [database],
            TABLE_SCHEMA AS [schema],
            TABLE_NAME AS [table],
            COLUMN_NAME AS [column],
            DATA_TYPE AS [type],
            NUMERIC_PRECISION AS [numeric_precision],
            NUMERIC_SCALE AS [numeric_scale]
        FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN (
            '{table}',
            '{table_trunc}'
        )

    """,
    'mysql': """
        SELECT
            TABLE_SCHEMA `database`,
            TABLE_SCHEMA `schema`,
            TABLE_NAME `table`,
            COLUMN_NAME `column`,
            DATA_TYPE `type`,
            NUMERIC_PRECISION `numeric_precision`,
            NUMERIC_SCALE `numeric_scale`
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'mariadb': """
        SELECT
            TABLE_SCHEMA `database`,
            TABLE_SCHEMA `schema`,
            TABLE_NAME `table`,
            COLUMN_NAME `column`,
            DATA_TYPE `type`,
            NUMERIC_PRECISION `numeric_precision`,
            NUMERIC_SCALE `numeric_scale`
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'oracle': """
        SELECT
            NULL AS "database",
            NULL AS "schema",
            TABLE_NAME AS "table",
            COLUMN_NAME AS "column",
            DATA_TYPE AS "type",
            DATA_PRECISION AS "numeric_precision",
            DATA_SCALE AS "numeric_scale"
        FROM all_tab_columns
        WHERE TABLE_NAME IN (
            '{table}',
            '{table_trunc}',
            '{table_lower}',
            '{table_lower_trunc}',
            '{table_upper}',
            '{table_upper_trunc}'
        )
    """,
}
### Queries to check whether a table is a TimescaleDB hypertable or a Citus
### distributed table (the size functions error out for regular tables).
hypertable_queries = {
    'timescaledb': 'SELECT hypertable_size(\'{table_name}\')',
    'citus': 'SELECT citus_table_size(\'{table_name}\')',
}
### Queries to fetch the indices (including primary keys) defined on a table,
### keyed by flavor. Each returns `database`, `schema`, `table`, `column`,
### `index`, and `index_type` columns (MSSQL adds a `clustered` bit).
### `{table}` / `{table_trunc}` (plus upper-case variants for Oracle) and
### `{schema}` are substituted by the caller.
columns_indices_queries = {
    'default': """
        SELECT
            current_database() AS "database",
            n.nspname AS "schema",
            t.relname AS "table",
            c.column_name AS "column",
            i.relname AS "index",
            CASE WHEN con.contype = 'p' THEN 'PRIMARY KEY' ELSE 'INDEX' END AS "index_type"
        FROM pg_class t
        INNER JOIN pg_index AS ix
            ON t.oid = ix.indrelid
        INNER JOIN pg_class AS i
            ON i.oid = ix.indexrelid
        INNER JOIN pg_namespace AS n
            ON n.oid = t.relnamespace
        INNER JOIN pg_attribute AS a
            ON a.attnum = ANY(ix.indkey)
            AND a.attrelid = t.oid
        INNER JOIN information_schema.columns AS c
            ON c.column_name = a.attname
            AND c.table_name = t.relname
            AND c.table_schema = n.nspname
        LEFT JOIN pg_constraint AS con
            ON con.conindid = i.oid
            AND con.contype = 'p'
        WHERE
            t.relname IN ('{table}', '{table_trunc}')
            AND n.nspname = '{schema}'
    """,
    'sqlite': """
        WITH indexed_columns AS (
            SELECT
                '{table}' AS table_name,
                pi.name AS column_name,
                i.name AS index_name,
                'INDEX' AS index_type
            FROM
                sqlite_master AS i,
                pragma_index_info(i.name) AS pi
            WHERE
                i.type = 'index'
                AND i.tbl_name = '{table}'
        ),
        primary_key_columns AS (
            SELECT
                '{table}' AS table_name,
                ti.name AS column_name,
                'PRIMARY_KEY' AS index_name,
                'PRIMARY KEY' AS index_type
            FROM
                pragma_table_info('{table}') AS ti
            WHERE
                ti.pk > 0
        )
        SELECT
            NULL AS "database",
            NULL AS "schema",
            "table_name" AS "table",
            "column_name" AS "column",
            "index_name" AS "index",
            "index_type"
        FROM indexed_columns
        UNION ALL
        SELECT
            NULL AS "database",
            NULL AS "schema",
            table_name AS "table",
            column_name AS "column",
            index_name AS "index",
            index_type
        FROM primary_key_columns
    """,
    'mssql': """
        SELECT
            NULL AS [database],
            s.name AS [schema],
            t.name AS [table],
            c.name AS [column],
            i.name AS [index],
            CASE
                WHEN kc.type = 'PK' THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS [index_type],
            CASE
                WHEN i.type = 1 THEN CAST(1 AS BIT)
                ELSE CAST(0 AS BIT)
            END AS [clustered]
        FROM
            sys.schemas s
        INNER JOIN sys.tables t
            ON s.schema_id = t.schema_id
        INNER JOIN sys.indexes i
            ON t.object_id = i.object_id
        INNER JOIN sys.index_columns ic
            ON i.object_id = ic.object_id
            AND i.index_id = ic.index_id
        INNER JOIN sys.columns c
            ON ic.object_id = c.object_id
            AND ic.column_id = c.column_id
        LEFT JOIN sys.key_constraints kc
            ON kc.parent_object_id = i.object_id
            AND kc.type = 'PK'
            AND kc.name = i.name
        WHERE
            t.name IN ('{table}', '{table_trunc}')
            AND s.name = '{schema}'
            AND i.type IN (1, 2)
    """,
    'oracle': """
        SELECT
            NULL AS "database",
            ic.table_owner AS "schema",
            ic.table_name AS "table",
            ic.column_name AS "column",
            i.index_name AS "index",
            CASE
                WHEN c.constraint_type = 'P' THEN 'PRIMARY KEY'
                WHEN i.uniqueness = 'UNIQUE' THEN 'UNIQUE INDEX'
                ELSE 'INDEX'
            END AS index_type
        FROM
            all_ind_columns ic
        INNER JOIN all_indexes i
            ON ic.index_name = i.index_name
            AND ic.table_owner = i.owner
        LEFT JOIN all_constraints c
            ON i.index_name = c.constraint_name
            AND i.table_owner = c.owner
            AND c.constraint_type = 'P'
        WHERE ic.table_name IN (
            '{table}',
            '{table_trunc}',
            '{table_upper}',
            '{table_upper_trunc}'
        )
    """,
    'mysql': """
        SELECT
            TABLE_SCHEMA AS `database`,
            TABLE_SCHEMA AS `schema`,
            TABLE_NAME AS `table`,
            COLUMN_NAME AS `column`,
            INDEX_NAME AS `index`,
            CASE
                WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS `index_type`
        FROM
            information_schema.STATISTICS
        WHERE
            TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
    'mariadb': """
        SELECT
            TABLE_SCHEMA AS `database`,
            TABLE_SCHEMA AS `schema`,
            TABLE_NAME AS `table`,
            COLUMN_NAME AS `column`,
            INDEX_NAME AS `index`,
            CASE
                WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'
                ELSE 'INDEX'
            END AS `index_type`
        FROM
            information_schema.STATISTICS
        WHERE
            TABLE_NAME IN ('{table}', '{table_trunc}')
    """,
}
### Queries to reset a table's auto-increment counter, keyed by flavor.
### Placeholders: `{table}` / `{column}` are unescaped names, `{table_name}` /
### `{column_name}` are escaped names, `{val}` is the new counter value,
### and `{val_plus_1}` (Oracle) restarts the identity at `val + 1`.
reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = {
    'default': """
        SELECT SETVAL(pg_get_serial_sequence('{table}', '{column}'), {val})
        FROM {table_name}
    """,
    'mssql': """
        DBCC CHECKIDENT ('{table}', RESEED, {val})
    """,
    'mysql': """
        ALTER TABLE {table_name} AUTO_INCREMENT = {val}
    """,
    'mariadb': """
        ALTER TABLE {table_name} AUTO_INCREMENT = {val}
    """,
    'sqlite': """
        UPDATE sqlite_sequence
        SET seq = {val}
        WHERE name = '{table}'
    """,
    'oracle': (
        "ALTER TABLE {table_name} MODIFY {column_name} "
        "GENERATED BY DEFAULT ON NULL AS IDENTITY (START WITH {val_plus_1})"
    ),
}
### Identifier quoting characters (left, right) per flavor.
table_wrappers = {
    'default'    : ('"', '"'),
    'timescaledb': ('"', '"'),
    'citus'      : ('"', '"'),
    'duckdb'     : ('"', '"'),
    'postgresql' : ('"', '"'),
    'sqlite'     : ('"', '"'),
    'mysql'      : ('`', '`'),
    'mariadb'    : ('`', '`'),
    'mssql'      : ('[', ']'),
    'cockroachdb': ('"', '"'),
    'oracle'     : ('"', '"'),
}
### Maximum identifier lengths per flavor (used by `truncate_item_name()`).
max_name_lens = {
    'default'    : 64,
    'mssql'      : 128,
    'oracle'     : 30,
    'postgresql' : 64,
    'timescaledb': 64,
    'citus'      : 64,
    'cockroachdb': 64,
    'sqlite'     : 1024, ### Probably more, but 1024 seems more than reasonable.
    'mysql'      : 64,
    'mariadb'    : 64,
}
### Flavors with native JSON column support.
json_flavors = {'postgresql', 'timescaledb', 'citus', 'cockroachdb'}
### Flavors without schema support (schemas are dropped for these).
NO_SCHEMA_FLAVORS = {'oracle', 'sqlite', 'mysql', 'mariadb', 'duckdb'}
### Default schema names per flavor (when none is configured).
DEFAULT_SCHEMA_FLAVORS = {
    'postgresql': 'public',
    'timescaledb': 'public',
    'citus': 'public',
    'cockroachdb': 'public',
    'mysql': 'mysql',
    'mariadb': 'mysql',
    'mssql': 'dbo',
}
### Flavors for which `NULLS FIRST` must be omitted from `ORDER BY`.
OMIT_NULLSFIRST_FLAVORS = {'mariadb', 'mysql', 'mssql'}

### Flavors limited to a single column per `ALTER TABLE` statement.
SINGLE_ALTER_TABLE_FLAVORS = {'duckdb', 'sqlite', 'mssql', 'oracle'}
### Flavors treated as lacking CTE (`WITH`) support.
NO_CTE_FLAVORS = {'mysql', 'mariadb'}
### Flavors without `SELECT ... INTO` support.
NO_SELECT_INTO_FLAVORS = {'sqlite', 'oracle', 'mysql', 'mariadb', 'duckdb'}
 520
 521
 522def clean(substring: str) -> str:
 523    """
 524    Ensure a substring is clean enough to be inserted into a SQL query.
 525    Raises an exception when banned words are used.
 526    """
 527    from meerschaum.utils.warnings import error
 528    banned_symbols = [';', '--', 'drop ',]
 529    for symbol in banned_symbols:
 530        if symbol in str(substring).lower():
 531            error(f"Invalid string: '{substring}'")
 532
 533
def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now',
    db_type: Optional[str] = None,
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime, int], default `'now'`
        Base datetime to which to add dateparts.
        An `int` is returned as a plain numeric literal (plus the offset).

    db_type: Optional[str], default None
        If provided, cast the datetime string as the type.
        Otherwise, infer this from the input datetime value.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor='mssql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
    >>> dateadd_str(
    ...     flavor='postgresql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type
    dateutil_parser = attempt_import('dateutil.parser')
    ### Integer `begin` values bypass all datetime handling:
    ### emit simple arithmetic on the literal instead.
    if 'int' in str(type(begin)).lower():
        num_str = str(begin)
        if number is not None and number != 0:
            num_str += (
                f' + {number}'
                if number > 0
                else f" - {number * -1}"
            )
        return num_str
    ### An empty / falsy `begin` produces no clause at all.
    if not begin:
        return ''

    _original_begin = begin
    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        ### Normalize timezone-aware datetimes to UTC before rendering.
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    ### Heuristic: treat the value as UTC if it is an aware datetime, or if the
    ### string contains a '+' or a '-' after the first ':' (a trailing offset).
    dt_is_utc = (
        begin_time.tzinfo is not None
        if begin_time is not None
        else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1])
    )
    if db_type:
        db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower()
        dt_is_utc = dt_is_utc or db_type_is_utc
    ### Infer the cast type from the UTC-ness when not explicitly given.
    db_type = db_type or get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### `begin` is a quoted datetime string here: drop the last three
        ### microsecond digits (keep millisecond precision) for naive values.
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = (
            f"CAST({begin} AS DATETIME(6))"
            if begin != 'now'
            else 'UTC_TIMESTAMP(6)'
        )
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        if begin == 'now':
            ### NOTE(review): this strftime format (r'%Y:%m:%d %M:%S.%f') omits the
            ### hour and separates date parts with ':', which does not match
            ### `dt_format` below ('YYYY-MM-DD HH24:MI:SS.FF') — confirm intended.
            begin = str(
                datetime.now(timezone.utc).replace(tzinfo=None).strftime(r'%Y:%m:%d %M:%S.%f')
            )
        elif begin_time:
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        ### Only quote + TO_TIMESTAMP when `begin` parsed as a real datetime.
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da
 704
 705
 706def test_connection(
 707    self,
 708    **kw: Any
 709) -> Union[bool, None]:
 710    """
 711    Test if a successful connection to the database may be made.
 712
 713    Parameters
 714    ----------
 715    **kw:
 716        The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.
 717
 718    Returns
 719    -------
 720    `True` if a connection is made, otherwise `False` or `None` in case of failure.
 721
 722    """
 723    import warnings
 724    from meerschaum.connectors.poll import retry_connect
 725    _default_kw = {'max_retries': 1, 'retry_wait': 0, 'warn': False, 'connector': self}
 726    _default_kw.update(kw)
 727    with warnings.catch_warnings():
 728        warnings.filterwarnings('ignore', 'Could not')
 729        try:
 730            return retry_connect(**_default_kw)
 731        except Exception:
 732            return False
 733
 734
 735def get_distinct_col_count(
 736    col: str,
 737    query: str,
 738    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
 739    debug: bool = False
 740) -> Optional[int]:
 741    """
 742    Returns the number of distinct items in a column of a SQL query.
 743
 744    Parameters
 745    ----------
 746    col: str:
 747        The column in the query to count.
 748
 749    query: str:
 750        The SQL query to count from.
 751
 752    connector: Optional[mrsm.connectors.sql.SQLConnector], default None:
 753        The SQLConnector to execute the query.
 754
 755    debug: bool, default False:
 756        Verbosity toggle.
 757
 758    Returns
 759    -------
 760    An `int` of the number of columns in the query or `None` if the query fails.
 761
 762    """
 763    if connector is None:
 764        connector = mrsm.get_connector('sql')
 765
 766    _col_name = sql_item_name(col, connector.flavor, None)
 767
 768    _meta_query = (
 769        f"""
 770        WITH src AS ( {query} ),
 771        dist AS ( SELECT DISTINCT {_col_name} FROM src )
 772        SELECT COUNT(*) FROM dist"""
 773    ) if connector.flavor not in ('mysql', 'mariadb') else (
 774        f"""
 775        SELECT COUNT(*)
 776        FROM (
 777            SELECT DISTINCT {_col_name}
 778            FROM ({query}) AS src
 779        ) AS dist"""
 780    )
 781
 782    result = connector.value(_meta_query, debug=debug)
 783    try:
 784        return int(result)
 785    except Exception:
 786        return None
 787
 788
 789def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
 790    """
 791    Parse SQL items depending on the flavor.
 792
 793    Parameters
 794    ----------
 795    item: str
 796        The database item (table, view, etc.) in need of quotes.
 797        
 798    flavor: str
 799        The database flavor (`'postgresql'`, `'mssql'`, `'sqllite'`, etc.).
 800
 801    schema: Optional[str], default None
 802        If provided, prefix the table name with the schema.
 803
 804    Returns
 805    -------
 806    A `str` which contains the input `item` wrapped in the corresponding escape characters.
 807    
 808    Examples
 809    --------
 810    >>> sql_item_name('table', 'sqlite')
 811    '"table"'
 812    >>> sql_item_name('table', 'mssql')
 813    "[table]"
 814    >>> sql_item_name('table', 'postgresql', schema='abc')
 815    '"abc"."table"'
 816
 817    """
 818    truncated_item = truncate_item_name(str(item), flavor)
 819    if flavor == 'oracle':
 820        truncated_item = pg_capital(truncated_item, quote_capitals=True)
 821        ### NOTE: System-reserved words must be quoted.
 822        if truncated_item.lower() in (
 823            'float', 'varchar', 'nvarchar', 'clob',
 824            'boolean', 'integer', 'table', 'row',
 825        ):
 826            wrappers = ('"', '"')
 827        else:
 828            wrappers = ('', '')
 829    else:
 830        wrappers = table_wrappers.get(flavor, table_wrappers['default'])
 831
 832    ### NOTE: SQLite does not support schemas.
 833    if flavor == 'sqlite':
 834        schema = None
 835    elif flavor == 'mssql' and str(item).startswith('#'):
 836        schema = None
 837
 838    schema_prefix = (
 839        (wrappers[0] + schema + wrappers[1] + '.')
 840        if schema is not None
 841        else ''
 842    )
 843
 844    return schema_prefix + wrappers[0] + truncated_item + wrappers[1]
 845
 846
 847def pg_capital(s: str, quote_capitals: bool = True) -> str:
 848    """
 849    If string contains a capital letter, wrap it in double quotes.
 850    
 851    Parameters
 852    ----------
 853    s: str
 854        The string to be escaped.
 855
 856    quote_capitals: bool, default True
 857        If `False`, do not quote strings with contain only a mix of capital and lower-case letters.
 858
 859    Returns
 860    -------
 861    The input string wrapped in quotes only if it needs them.
 862
 863    Examples
 864    --------
 865    >>> pg_capital("My Table")
 866    '"My Table"'
 867    >>> pg_capital('my_table')
 868    'my_table'
 869
 870    """
 871    if s.startswith('"') and s.endswith('"'):
 872        return s
 873
 874    s = s.replace('"', '')
 875
 876    needs_quotes = s.startswith('_')
 877    if not needs_quotes:
 878        for c in s:
 879            if c == '_':
 880                continue
 881
 882            if not c.isalnum() or (quote_capitals and c.isupper()):
 883                needs_quotes = True
 884                break
 885
 886    if needs_quotes:
 887        return '"' + s + '"'
 888
 889    return s
 890
 891
 892def oracle_capital(s: str) -> str:
 893    """
 894    Capitalize the string of an item on an Oracle database.
 895    """
 896    return s
 897
 898
 899def truncate_item_name(item: str, flavor: str) -> str:
 900    """
 901    Truncate item names to stay within the database flavor's character limit.
 902
 903    Parameters
 904    ----------
 905    item: str
 906        The database item being referenced. This string is the "canonical" name internally.
 907
 908    flavor: str
 909        The flavor of the database on which `item` resides.
 910
 911    Returns
 912    -------
 913    The truncated string.
 914    """
 915    from meerschaum.utils.misc import truncate_string_sections
 916    return truncate_string_sections(
 917        item, max_len=max_name_lens.get(flavor, max_name_lens['default'])
 918    )
 919
 920
 921def build_where(
 922    params: Dict[str, Any],
 923    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
 924    with_where: bool = True,
 925) -> str:
 926    """
 927    Build the `WHERE` clause based on the input criteria.
 928
 929    Parameters
 930    ----------
 931    params: Dict[str, Any]:
 932        The keywords dictionary to convert into a WHERE clause.
 933        If a value is a string which begins with an underscore, negate that value
 934        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
 935        A value of `_None` will be interpreted as `IS NOT NULL`.
 936
 937    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None:
 938        The Meerschaum SQLConnector that will be executing the query.
 939        The connector is used to extract the SQL dialect.
 940
 941    with_where: bool, default True:
 942        If `True`, include the leading `'WHERE'` string.
 943
 944    Returns
 945    -------
 946    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.
 947
 948    Examples
 949    --------
 950    ```
 951    >>> print(build_where({'foo': [1, 2, 3]}))
 952    
 953    WHERE
 954        "foo" IN ('1', '2', '3')
 955    ```
 956    """
 957    import json
 958    from meerschaum.config.static import STATIC_CONFIG
 959    from meerschaum.utils.warnings import warn
 960    from meerschaum.utils.dtypes import value_is_null, none_if_null
 961    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
 962    try:
 963        params_json = json.dumps(params)
 964    except Exception as e:
 965        params_json = str(params)
 966    bad_words = ['drop ', '--', ';']
 967    for word in bad_words:
 968        if word in params_json.lower():
 969            warn(f"Aborting build_where() due to possible SQL injection.")
 970            return ''
 971
 972    if connector is None:
 973        from meerschaum import get_connector
 974        connector = get_connector('sql')
 975    where = ""
 976    leading_and = "\n    AND "
 977    for key, value in params.items():
 978        _key = sql_item_name(key, connector.flavor, None)
 979        ### search across a list (i.e. IN syntax)
 980        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
 981            includes = [
 982                none_if_null(item)
 983                for item in value
 984                if not str(item).startswith(negation_prefix)
 985            ]
 986            null_includes = [item for item in includes if item is None]
 987            not_null_includes = [item for item in includes if item is not None]
 988            excludes = [
 989                none_if_null(str(item)[len(negation_prefix):])
 990                for item in value
 991                if str(item).startswith(negation_prefix)
 992            ]
 993            null_excludes = [item for item in excludes if item is None]
 994            not_null_excludes = [item for item in excludes if item is not None]
 995
 996            if includes:
 997                where += f"{leading_and}("
 998            if not_null_includes:
 999                where += f"{_key} IN ("
1000                for item in not_null_includes:
1001                    quoted_item = str(item).replace("'", "''")
1002                    where += f"'{quoted_item}', "
1003                where = where[:-2] + ")"
1004            if null_includes:
1005                where += ("\n    OR " if not_null_includes else "") + f"{_key} IS NULL"
1006            if includes:
1007                where += ")"
1008
1009            if excludes:
1010                where += f"{leading_and}("
1011            if not_null_excludes:
1012                where += f"{_key} NOT IN ("
1013                for item in not_null_excludes:
1014                    quoted_item = str(item).replace("'", "''")
1015                    where += f"'{quoted_item}', "
1016                where = where[:-2] + ")"
1017            if null_excludes:
1018                where += ("\n    AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
1019            if excludes:
1020                where += ")"
1021
1022            continue
1023
1024        ### search a dictionary
1025        elif isinstance(value, dict):
1026            import json
1027            where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'")
1028            continue
1029
1030        eq_sign = '='
1031        is_null = 'IS NULL'
1032        if value_is_null(str(value).lstrip(negation_prefix)):
1033            value = (
1034                (negation_prefix + 'None')
1035                if str(value).startswith(negation_prefix)
1036                else None
1037            )
1038        if str(value).startswith(negation_prefix):
1039            value = str(value)[len(negation_prefix):]
1040            eq_sign = '!='
1041            if value_is_null(value):
1042                value = None
1043                is_null = 'IS NOT NULL'
1044        quoted_value = str(value).replace("'", "''")
1045        where += (
1046            f"{leading_and}{_key} "
1047            + (is_null if value is None else f"{eq_sign} '{quoted_value}'")
1048        )
1049
1050    if len(where) > 1:
1051        where = ("\nWHERE\n    " if with_where else '') + where[len(leading_and):]
1052    return where
1053
1054
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """Check if a table exists.

    Parameters
    ----------
    table: str:
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False :
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    schema = schema or connector.schema
    ### Truncate the name to the flavor's identifier limit before inspecting.
    table_to_check = truncate_item_name(str(table), connector.flavor)
    inspector = sqlalchemy.inspect(connector.engine)
    return inspector.has_table(table_to_check, schema=schema)
1087
1088
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None:
        The connector to the database which holds the table.
        Defaults to the default SQL connector.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False:
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table, or `None` if reflection fails
    (or if the connector's flavor is DuckDB).
    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    ### DuckDB does not use cached SQLAlchemy table objects.
    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn
    if refresh:
        connector.metadata.clear()
    tables = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy', lazy=False)
    truncated_name = truncate_item_name(str(table), connector.flavor)

    reflect_kwargs = {'autoload_with': connector.engine}
    if schema:
        reflect_kwargs['schema'] = schema

    ### Serve the cached table object unless a rebuild was requested.
    if not refresh and truncated_name in tables:
        return tables[truncated_name]

    try:
        tables[truncated_name] = sqlalchemy.Table(
            truncated_name,
            connector.metadata,
            **reflect_kwargs
        )
    except sqlalchemy.exc.NoSuchTableError:
        warn(f"Table '{truncated_name}' does not exist in '{connector}'.")
        return None
    return tables[truncated_name]
1154
1155
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
        Use this function if you are confident the table name is unique or if you have
        an explicit schema.
        To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to data types.
    Returns an empty dictionary if the query fails.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Pre-compute the name variants used by the flavor-specific query templates.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in `tempdb`.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        textwrap.dedent(columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )).lstrip().rstrip()
    )

    ### Positional column order expected from the query's result rows.
    cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale']
    result_cols_ix = dict(enumerate(cols))

    ### Only SQLConnector.execute() accepts `debug`; raw sessions/engines do not.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        ### DuckDB requires SQLConnector.read() instead of execute().
        result_rows = (
            [
                row
                for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        ### Convert positional rows into documents keyed by the expected column names.
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle upper-cases unquoted identifiers, so fold all-caps alphabetic
        ### names back to lower-case for the canonical internal names.
        ### NOTE: a precision/scale suffix (e.g. `(10,2)`) is only appended when
        ### both values are truthy, so a scale of 0 omits the suffix entirely.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper() + (
                f'({precision},{scale})'
                if (
                    (precision := doc.get('numeric_precision', None))
                     and
                    (scale := doc.get('numeric_scale', None))
                )
                else ''
            )
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
1313
1314
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[str]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
        Use this function if you are confident the table name is unique or if you have
        an explicit schema.
        To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to a list of indices.
    Each index is a dictionary with the keys `'name'` and `'type'`
    (plus `'clustered'` on MSSQL).
    Returns an empty dictionary if the query fails.
    """
    import textwrap
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Pre-compute the name variants used by the flavor-specific query templates.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temp tables (names starting with '#') live in `tempdb`.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        textwrap.dedent(columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )).lstrip().rstrip()
    )

    ### Positional column order expected from the query's result rows.
    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    if flavor == 'mssql':
        cols.append('clustered')
    result_cols_ix = dict(enumerate(cols))

    ### Only SQLConnector.execute() accepts `debug`; raw sessions/engines do not.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        ### DuckDB requires SQLConnector.read() instead of execute().
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        ### Convert positional rows into documents keyed by the expected column names.
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]
        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Group the index documents by (case-normalized) column name.
        cols_indices = defaultdict(lambda: [])
        for doc in cols_types_docs_filtered:
            ### Oracle upper-cases unquoted identifiers, so fold all-caps
            ### alphabetic names back to lower-case for the canonical names.
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            index_doc = {
                'name': doc.get('index', None),
                'type': doc.get('index_type', None)
            }
            if flavor == 'mssql':
                index_doc['clustered'] = doc.get('clustered', None)
            cols_indices[col].append(index_doc)

        return dict(cols_indices)
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}
1474
1475
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    identity_insert: bool = False,
    null_indices: bool = True,
    cast_columns: bool = True,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: Iterable[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    identity_insert: bool, default False
        If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries.
        Only applies for MSSQL upserts.

    null_indices: bool, default True
        If `False`, do not coalesce index columns before joining.

    cast_columns: bool, default True
        If `False`, do not cast update columns to the target table types.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes import are_dtypes_equal
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")
    ### SQLite gained `UPDATE ... FROM` in 3.33.0; older versions fall back
    ### to a dedicated DELETE + INSERT template.
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and connectable.db_version < '3.33.0'
    ):
        flavor = 'sqlite_delete_insert'
    ### Select the flavor's template(s); upserts use a separate key.
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = UPDATE_QUERIES.get(
        flavor_key,
        UPDATE_QUERIES['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    ### Fetch both tables' column types so shared columns can be partitioned
    ### into join (index) columns and value columns below.
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Partition the columns shared by both tables into join columns
    ### (used in the ON clause) and value columns (used in the SET clause).
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        ### Normalize the DB type to the flavor's preferred CAST target.
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    ### Nothing to join on, or nothing to update (and not upserting): no-op.
    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    ### COALESCE NULL index values to sentinel values so NULLs join to NULLs.
    coalesce_join_cols_str = ', '.join(
        [
            (
                (
                    'COALESCE('
                    + sql_item_name(c_name, flavor)
                    + ', '
                    + get_null_replacement(c_type, flavor)
                    + ')'
                )
                if null_indices
                else sql_item_name(c_name, flavor)
            )
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        """Build the `SET` clause, casting each value column to the target type."""
        if not value_cols:
            return ''

        ### Timezone-aware flavors need `AT TIME ZONE 'UTC'` on UTC columns.
        utc_value_cols = {
            c_name
            for c_name, c_type in value_cols
            if ('utc' in get_pd_type_from_db_type(c_type).lower())
        } if flavor not in TIMEZONE_NAIVE_FLAVORS else set()

        ### Per-column (prefix, cast-type, suffix) wrappers around the RHS value.
        cast_func_cols = {
            c_name: (
                ('', '', '')
                if not cast_columns or (
                    flavor == 'oracle'
                    and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes')
                )
                else (
                    ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + (
                        " AT TIME ZONE 'UTC'"
                        if c_name in utc_value_cols
                        else ''
                    ))
                    if flavor != 'sqlite'
                    else ('', '', '')
                )
            )
            for c_name, c_type in value_cols
        }
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + cast_func_cols[c_name][0]
                + r_prefix + sql_item_name(c_name, flavor, None)
                + cast_func_cols[c_name][1]
                + cast_func_cols[c_name][2]
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        """Build the join condition, optionally COALESCE-ing NULL index values."""
        return '\n            AND\n                '.join([
            (
                (
                    "COALESCE("
                    + l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                    + '\n                =\n                '
                    + "COALESCE("
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                )
                if null_indices
                else (
                    l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ' = '
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                )
            ) for c_name, c_type in join_cols_types
        ])

    ### Placeholders which format to empty strings are filtered out at the end.
    skip_query_val = ""
    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    ### MSSQL computes the datetime bounds in a CTE; other flavors inline MIN/MAX.
    date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]'
    min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]'
    max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]'
    date_bounds_subquery = (
        f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table})
            AND
                f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})"""
        if datetime_col
        else "1 = 1"
    )
    with_temp_date_bounds = f"""WITH [date_bounds] AS (
        SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name}
        FROM {patch_table_name}
    )""" if datetime_col else ""
    identity_insert_on = (
        f"SET IDENTITY_INSERT {target_table_name} ON"
        if identity_insert
        else skip_query_val
    )
    identity_insert_off = (
        f"SET IDENTITY_INSERT {target_table_name} OFF"
        if identity_insert
        else skip_query_val
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "\n        WHEN MATCHED THEN\n"
        f"            UPDATE {sets_subquery('', 'p.')}"
    )

    ### MySQL/MariaDB upsert pieces (`INSERT ... ON DUPLICATE KEY UPDATE`).
    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    ### Render each template with every placeholder; unused ones are ignored.
    formatted_queries = [
        textwrap.dedent(base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
            with_temp_date_bounds=with_temp_date_bounds,
            identity_insert_on=identity_insert_on,
            identity_insert_off=identity_insert_off,
        )).lstrip().rstrip()
        for base_query in base_queries
    ]

    ### NOTE: Allow for skipping some queries.
    return [query for query in formatted_queries if query]
1791
1792
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a magic value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The database type to be converted to NULL.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A SQL literal (as a string) which may stand in place of NULL for this type.
    The string fallback `"'-987654321'"` is returned if no specific match is found.
    """
    typ_lower = typ.lower()

    if 'int' in typ_lower or typ_lower in ('numeric', 'number'):
        return '-987654321'

    if 'bool' in typ_lower or typ_lower == 'bit':
        ### Import lazily: only the boolean branch needs the cast-dtypes lookup.
        from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        ### MySQL / MariaDB get the negative magic value; others cast 0.
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'

    if 'time' in typ_lower or 'date' in typ_lower:
        ### An upper-case `typ` is passed through as an explicit DB type.
        db_type = typ if typ.isupper() else None
        return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type)

    if 'float' in typ_lower or 'double' in typ_lower or typ_lower in ('decimal',):
        return '-987654321.0'

    if flavor == 'oracle' and typ_lower.split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"

    if flavor == 'oracle' and typ_lower in ('blob', 'bytes'):
        return '00'

    if typ_lower in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"

    ### Fallback: string literal (NVARCHAR literal prefix for Oracle).
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"
1843
1844
def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Return the version string reported by the database, if available.
    """
    quoted_version = sql_item_name('version', conn.flavor, None)
    query_template = version_queries.get(conn.flavor, version_queries['default'])
    return conn.value(query_template.format(version_name=quoted_version), debug=debug)
1855
1856
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    old_table_name = sql_item_name(old_table, flavor, schema)
    new_table_name = sql_item_name(new_table, flavor, None)
    tmp_table = '_tmp_rename_' + new_table
    tmp_table_name = sql_item_name(tmp_table, flavor, schema)

    if flavor == 'mssql':
        ### MSSQL renames via the `sp_rename` procedure with unquoted names.
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""

    if flavor == 'duckdb':
        ### DuckDB: copy into a temporary table, copy that into the new
        ### name, then drop both the temporary and the original tables.
        copy_to_tmp = get_create_table_queries(
            f"SELECT * FROM {old_table_name}",
            tmp_table,
            'duckdb',
            schema,
        )
        copy_to_new = get_create_table_queries(
            f"SELECT * FROM {tmp_table_name}",
            new_table,
            'duckdb',
            schema,
        )
        drop_queries = [
            f"DROP TABLE {if_exists_str} {tmp_table_name}",
            f"DROP TABLE {if_exists_str} {old_table_name}",
        ]
        return copy_to_tmp + copy_to_new + drop_queries

    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]
1911
1912
def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    ### Delegate to the newer API and keep only the first query.
    queries = get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )
    return queries[0]
1950
1951
def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return queries to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    primary_key_db_type: Optional[str], default None
        If provided, alter the primary key to this type (to set NOT NULL constraint).

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A list of `CREATE TABLE` (or `SELECT INTO`) queries for the database flavor.

    Raises
    ------
    TypeError
        If `query_or_dtypes` is neither a string nor a dictionary.
    """
    if isinstance(query_or_dtypes, str):
        builder = _get_create_table_query_from_cte
    elif isinstance(query_or_dtypes, dict):
        builder = _get_create_table_query_from_dtypes
    else:
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    ### Some flavors (e.g. DuckDB, Citus) do not support auto-increment here.
    allow_autoincrement = autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS
    return builder(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        primary_key_db_type=primary_key_db_type,
        autoincrement=allow_autoincrement,
        datetime_column=datetime_column,
    )
2016
2017
def _get_create_table_query_from_dtypes(
    dtypes: Dict[str, str],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Create a new table from a `dtypes` dictionary.

    Parameters
    ----------
    dtypes: Dict[str, str]
        A mapping of column names to Pandas dtype strings, each translated
        to the flavor's database type.

    new_table: str
        The unquoted name of the table to create.

    flavor: str
        The database flavor for which to build the query.

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key.
        If absent from `dtypes`, the column defaults to an `int` type.

    primary_key_db_type: Optional[str], default None
        NOTE(review): accepted for signature parity with
        `_get_create_table_query_from_cte()` but unused in this function.

    autoincrement: bool, default False
        If `True`, create the primary key as an auto-incrementing column
        (ignored for flavors in `SKIP_AUTO_INCREMENT_FLAVORS`).

    datetime_column: Optional[str], default None
        If provided (TimescaleDB), include this column in a composite primary key.

    Returns
    -------
    A single-element list containing the `CREATE TABLE` query.

    Raises
    ------
    ValueError
        If neither `dtypes` nor `primary_key` is provided.
    """
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, AUTO_INCREMENT_COLUMN_FLAVORS
    if not dtypes and not primary_key:
        raise ValueError(f"Expecting columns for table '{new_table}'.")

    if flavor in SKIP_AUTO_INCREMENT_FLAVORS:
        autoincrement = False

    ### Build (column, db_type) pairs with the primary key (if any) first.
    ### A primary key missing from `dtypes` defaults to an `int` column.
    cols_types = (
        [
            (
                primary_key,
                get_db_type_from_pd_type(dtypes.get(primary_key, 'int') or 'int', flavor=flavor)
            )
        ]
        if primary_key
        else []
    ) + [
        (col, get_db_type_from_pd_type(typ, flavor=flavor))
        for col, typ in dtypes.items()
        if col != primary_key
    ]

    table_name = sql_item_name(new_table, schema=schema, flavor=flavor)
    primary_key_name = sql_item_name(primary_key, flavor) if primary_key else None
    primary_key_constraint_name = (
        sql_item_name(f'PK_{new_table}', flavor, None)
        if primary_key
        else None
    )
    datetime_column_name = sql_item_name(datetime_column, flavor) if datetime_column else None
    ### MSSQL: cluster on the primary key unless a separate datetime column exists.
    primary_key_clustered = (
        "CLUSTERED"
        if not datetime_column or datetime_column == primary_key
        else "NONCLUSTERED"
    )
    query = f"CREATE TABLE {table_name} ("
    if primary_key:
        col_db_type = cols_types[0][1]
        ### Auto-increment when requested or when the primary key was not
        ### explicitly typed in `dtypes`.
        auto_increment_str = (' ' + AUTO_INCREMENT_COLUMN_FLAVORS.get(
            flavor,
            AUTO_INCREMENT_COLUMN_FLAVORS['default']
        )) if autoincrement or primary_key not in dtypes else ''
        col_name = sql_item_name(primary_key, flavor=flavor, schema=None)

        if flavor == 'sqlite':
            ### SQLite auto-increment requires the column be declared INTEGER.
            query += (
                f"\n    {col_name} "
                + (f"{col_db_type}" if not auto_increment_str else 'INTEGER')
                + f" PRIMARY KEY{auto_increment_str} NOT NULL,"
            )
        elif flavor == 'oracle':
            query += f"\n    {col_name} {col_db_type} {auto_increment_str} PRIMARY KEY,"
        elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key:
            ### The composite PRIMARY KEY clause is appended below instead.
            query += f"\n    {col_name} {col_db_type}{auto_increment_str} NOT NULL,"
        elif flavor == 'mssql':
            ### MSSQL adds the PRIMARY KEY via a named constraint below.
            query += f"\n    {col_name} {col_db_type}{auto_increment_str} NOT NULL,"
        else:
            query += f"\n    {col_name} {col_db_type} PRIMARY KEY{auto_increment_str} NOT NULL,"

    for col, db_type in cols_types:
        if col == primary_key:
            continue
        col_name = sql_item_name(col, schema=None, flavor=flavor)
        query += f"\n    {col_name} {db_type},"
    if (
        flavor == 'timescaledb'
        and datetime_column
        and primary_key
        and datetime_column != primary_key
    ):
        query += f"\n    PRIMARY KEY({datetime_column_name}, {primary_key_name}),"

    if flavor == 'mssql' and primary_key:
        query += f"\n    CONSTRAINT {primary_key_constraint_name} PRIMARY KEY {primary_key_clustered} ({primary_key_name}),"

    ### Drop the trailing comma and close the column list.
    query = query[:-1]
    query += "\n)"

    queries = [query]
    return queries
2111
2112
def _get_create_table_query_from_cte(
    query: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Create a new table from a CTE query.

    Parameters
    ----------
    query: str
        The `SELECT` query (possibly itself containing CTEs) whose result set
        defines the new table.

    new_table: str
        The unquoted name of the table to create.

    flavor: str
        The database flavor for which to build the queries.

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, append `ALTER TABLE` queries to add the primary key.

    primary_key_db_type: Optional[str], default None
        If provided (MSSQL), alter the primary key column to this type
        (to set the NOT NULL constraint) before adding the key.

    autoincrement: bool, default False
        Unused here; accepted for signature parity with
        `_get_create_table_query_from_dtypes()`.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key (TimescaleDB)
        and use a nonclustered primary key index (MSSQL).

    Returns
    -------
    A list of queries: the `CREATE TABLE` / `SELECT INTO` query, plus
    `ALTER TABLE` queries when `primary_key` is set.
    """
    import textwrap
    create_cte = 'create_query'
    create_cte_name = sql_item_name(create_cte, flavor, None)
    new_table_name = sql_item_name(new_table, flavor, schema)
    primary_key_constraint_name = (
        sql_item_name(f'PK_{new_table}', flavor, None)
        if primary_key
        else None
    )
    primary_key_name = (
        sql_item_name(primary_key, flavor, None)
        if primary_key
        else None
    )
    primary_key_clustered = (
        "CLUSTERED"
        if not datetime_column or datetime_column == primary_key
        else "NONCLUSTERED"
    )
    datetime_column_name = (
        sql_item_name(datetime_column, flavor)
        if datetime_column
        else None
    )
    if flavor in ('mssql',):
        query = query.lstrip()
        if query.lower().startswith('with '):
            ### Extend the source query's own CTE list rather than nesting CTEs.
            final_select_ix = query.lower().rfind('select')
            create_table_queries = [
                (
                    query[:final_select_ix].rstrip() + ',\n'
                    + f"{create_cte_name} AS (\n"
                    + textwrap.indent(query[final_select_ix:], '    ')
                    + "\n)\n"
                    + f"SELECT *\nINTO {new_table_name}\nFROM {create_cte_name}"
                ),
            ]
        else:
            create_table_queries = [
                (
                    "SELECT *\n"
                    f"INTO {new_table_name}\n"
                    f"FROM (\n{textwrap.indent(query, '    ')}\n) AS {create_cte_name}"
                ),
            ]

        alter_type_queries = []
        if primary_key_db_type:
            alter_type_queries.extend([
                (
                    f"ALTER TABLE {new_table_name}\n"
                    f"ALTER COLUMN {primary_key_name} {primary_key_db_type} NOT NULL"
                ),
            ])
        alter_type_queries.extend([
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD CONSTRAINT {primary_key_constraint_name} "
                f"PRIMARY KEY {primary_key_clustered} ({primary_key_name})"
            ),
        ])
    elif flavor in (None,):
        ### NOTE: currently unreachable (`flavor` is never None); kept for parity.
        ### Fixed: `textwrap.index` does not exist; use `textwrap.indent`.
        create_table_queries = [
            (
                f"WITH {create_cte_name} AS (\n{textwrap.indent(query, '    ')}\n)\n"
                f"CREATE TABLE {new_table_name} AS\n"
                "SELECT *\n"
                f"FROM {create_cte_name}"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]
    elif flavor in ('sqlite', 'mysql', 'mariadb', 'duckdb', 'oracle'):
        create_table_queries = [
            (
                f"CREATE TABLE {new_table_name} AS\n"
                "SELECT *\n"
                f"FROM (\n{textwrap.indent(query, '    ')}\n)"
                + (f" AS {create_cte_name}" if flavor != 'oracle' else '')
            ),
        ]

        ### Fixed: the second line was missing its `f` prefix, emitting the
        ### literal placeholder `{primary_key_name}` into the query.
        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]
    elif flavor == 'timescaledb' and datetime_column and datetime_column != primary_key:
        create_table_queries = [
            (
                "SELECT *\n"
                f"INTO {new_table_name}\n"
                f"FROM (\n{textwrap.indent(query, '    ')}\n) AS {create_cte_name}\n"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({datetime_column_name}, {primary_key_name})"
            ),
        ]
    else:
        create_table_queries = [
            (
                "SELECT *\n"
                f"INTO {new_table_name}\n"
                f"FROM (\n{textwrap.indent(query, '    ')}\n) AS {create_cte_name}"
            ),
        ]

        alter_type_queries = [
            (
                f"ALTER TABLE {new_table_name}\n"
                f"ADD PRIMARY KEY ({primary_key_name})"
            ),
        ]

    if not primary_key:
        return create_table_queries

    return create_table_queries + alter_type_queries
2254
2255
def wrap_query_with_cte(
    sub_query: str,
    parent_query: str,
    flavor: str,
    cte_name: str = "src",
) -> str:
    """
    Wrap a subquery in a CTE and append an encapsulating query.

    Parameters
    ----------
    sub_query: str
        The query to be referenced. This may itself contain CTEs.
        Unless `cte_name` is provided, this will be aliased as `src`.

    parent_query: str
        The larger query to append which references the subquery.
        This must not contain CTEs.

    flavor: str
        The database flavor, e.g. `'mssql'`.

    cte_name: str, default 'src'
        The CTE alias, defaults to `src`.

    Returns
    -------
    An encapsulating query which allows you to treat `sub_query` as a temporary table.

    Examples
    --------

    ```python
    from meerschaum.utils.sql import wrap_query_with_cte
    sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
    parent_query = "SELECT newval * 3 FROM src"
    query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
    print(query)
    # WITH foo AS (SELECT 1 AS val),
    # [src] AS (
    #     SELECT (val * 2) AS newval FROM foo
    # )
    # SELECT newval * 3 FROM src
    ```

    """
    import textwrap
    sub_query = sub_query.lstrip()
    quoted_cte = sql_item_name(cte_name, flavor, None)

    if flavor in NO_CTE_FLAVORS:
        ### No CTE support: splice the subquery in as a derived table instead.
        placeholder = '--MRSM_SUBQUERY--'
        derived_table = f"(\n{sub_query}\n) AS {quoted_cte}"
        swapped = parent_query.replace(quoted_cte, placeholder)
        swapped = swapped.replace(cte_name, placeholder)
        return swapped.replace(placeholder, derived_table)

    if sub_query.lower().startswith('with '):
        ### The subquery has its own CTE list: extend it with one more CTE
        ### wrapping the final SELECT, rather than nesting WITH clauses.
        split_ix = sub_query.lower().rfind('select')
        leading_ctes = sub_query[:split_ix].rstrip()
        final_select = sub_query[split_ix:]
        return (
            leading_ctes + ',\n'
            + f"{quoted_cte} AS (\n"
            + '    ' + final_select
            + "\n)\n"
            + parent_query
        )

    indented_sub_query = textwrap.indent(sub_query, '    ')
    return f"WITH {quoted_cte} AS (\n{indented_sub_query}\n)\n{parent_query}"
2329
2330
def format_cte_subquery(
    sub_query: str,
    flavor: str,
    sub_name: str = 'src',
    cols_to_select: Union[List[str], str] = '*',
) -> str:
    """
    Given a subquery, build a wrapper query that selects from the CTE subquery.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        If possible, give this name to the CTE (must be unquoted).

    cols_to_select: Union[List[str], str], default '*'
        If specified, choose which columns to select from the CTE.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the CTE.
    """
    quoted_sub_name = sql_item_name(sub_name, flavor, None)
    if isinstance(cols_to_select, str):
        cols_str = cols_to_select
    else:
        ### Quote each column individually before joining.
        cols_str = ', '.join([sql_item_name(col, flavor, None) for col in cols_to_select])
    parent_query = f"SELECT {cols_str}\nFROM {quoted_sub_name}"
    return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)
2371
2372
def session_execute(
    session: 'sqlalchemy.orm.session.Session',
    queries: Union[List[str], str],
    with_results: bool = False,
    debug: bool = False,
) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session.

    with_results: bool, default False
        If `True`, return a list of result objects.

    Returns
    -------
    A `SuccessTuple` indicating the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    if not isinstance(queries, list):
        queries = [queries]

    successes, msgs, results = [], [], []
    fail_msg = "Failed to execute queries."
    for query in queries:
        if debug:
            dprint(query)
        query_text = sqlalchemy.text(query)
        try:
            result = session.execute(query_text)
        except Exception as e:
            result = None
            query_success = False
            query_msg = f"{fail_msg}\n{e}"
        else:
            ### A `None` result is treated as a failure.
            query_success = result is not None
            query_msg = "Success" if query_success else fail_msg

        successes.append(query_success)
        msgs.append(query_msg)
        results.append(result)

        if not query_success:
            if debug:
                dprint("Rolling back session.")
            session.rollback()
            break

    success, msg = all(successes), '\n'.join(msgs)
    if with_results:
        return (success, msg), results
    return success, msg
2429
2430
def get_reset_autoincrement_queries(
    table: str,
    column: str,
    connector: mrsm.connectors.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Return a list of queries to reset a table's auto-increment counter to the next largest value.

    Parameters
    ----------
    table: str
        The name of the table on which the auto-incrementing column exists.

    column: str
        The name of the auto-incrementing column.

    connector: mrsm.connectors.SQLConnector
        The SQLConnector to the database on which the table exists.

    schema: Optional[str], default None
        The schema of the table. Defaults to `connector.schema`.

    Returns
    -------
    A list of queries to be executed to reset the auto-incrementing column.
    """
    if not table_exists(table, connector, schema=schema, debug=debug):
        return []

    schema = schema or connector.schema
    flavor = connector.flavor
    max_id_name = sql_item_name('max_id', flavor)
    table_name = sql_item_name(table, flavor, schema)
    table_seq_name = sql_item_name(f"{table}_{column}_seq", flavor, schema)
    column_name = sql_item_name(column, flavor)

    ### Find the current largest value (0 for an empty table).
    max_id = connector.value(
        f"""
        SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name}
        FROM {table_name}
        """,
        debug=debug,
    )
    if max_id is None:
        return []

    flavor_queries = reset_autoincrement_queries.get(
        flavor,
        reset_autoincrement_queries['default']
    )
    queries_to_format = (
        flavor_queries
        if isinstance(flavor_queries, list)
        else [flavor_queries]
    )

    formatted_queries = []
    for reset_query in queries_to_format:
        formatted_queries.append(
            reset_query.format(
                column=column,
                column_name=column_name,
                table=table,
                table_name=table_name,
                table_seq_name=table_seq_name,
                val=max_id,
                val_plus_1=(max_id + 1),
            )
        )
    return formatted_queries
### NOTE(review): the constants below duplicate the definitions near the top of
### this file; being later, these bindings silently shadow the earlier ones.
### Presumably a generated/appended artifact -- confirm and deduplicate.
test_queries = {'default': 'SELECT 1', 'oracle': 'SELECT 1 FROM DUAL', 'informix': 'SELECT COUNT(*) FROM systables', 'hsqldb': 'SELECT 1 FROM INFORMATION_SCHEMA.SYSTEM_USERS'}
exists_queries = {'default': 'SELECT COUNT(*) FROM {table_name} WHERE 1 = 0'}
version_queries = {'default': 'SELECT VERSION() AS {version_name}', 'sqlite': 'SELECT SQLITE_VERSION() AS {version_name}', 'mssql': 'SELECT @@version', 'oracle': 'SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1'}
SKIP_IF_EXISTS_FLAVORS = {'oracle', 'mssql'}
DROP_IF_EXISTS_FLAVORS = {'mariadb', 'mysql', 'sqlite', 'mssql', 'postgresql', 'timescaledb', 'citus'}
DROP_INDEX_IF_EXISTS_FLAVORS = {'sqlite', 'mssql', 'postgresql', 'timescaledb', 'citus'}
SKIP_AUTO_INCREMENT_FLAVORS = {'duckdb', 'citus'}
COALESCE_UNIQUE_INDEX_FLAVORS = {'timescaledb', 'citus', 'postgresql'}
UPDATE_QUERIES = {'default': '\n UPDATE {target_table_name} AS f\n {sets_subquery_none}\n FROM {target_table_name} AS t\n INNER JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p\n ON\n {and_subquery_t}\n WHERE\n {and_subquery_f}\n AND\n {date_bounds_subquery}\n ', 'timescaledb-upsert': '\n INSERT INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}\n ', 'postgresql-upsert': '\n INSERT INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}\n ', 'citus-upsert': '\n INSERT INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}\n ', 'cockroachdb-upsert': '\n INSERT INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}\n ', 'mysql': '\n UPDATE {target_table_name} AS f\n JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p\n ON\n {and_subquery_f}\n {sets_subquery_f}\n WHERE\n {date_bounds_subquery}\n ', 'mysql-upsert': '\n INSERT {ignore}INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n {on_duplicate_key_update}\n {cols_equal_values}\n ', 'mariadb': '\n UPDATE {target_table_name} AS f\n JOIN (SELECT {patch_cols_str} FROM {patch_table_name}) AS p\n ON\n {and_subquery_f}\n {sets_subquery_f}\n WHERE\n {date_bounds_subquery}\n ', 'mariadb-upsert': '\n INSERT {ignore}INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n {on_duplicate_key_update}\n {cols_equal_values}\n ', 'mssql': '\n {with_temp_date_bounds}\n MERGE {target_table_name} f\n USING (SELECT {patch_cols_str} 
FROM {patch_table_name}) p\n ON\n {and_subquery_f}\n AND\n {date_bounds_subquery}\n WHEN MATCHED THEN\n UPDATE\n {sets_subquery_none};\n ', 'mssql-upsert': ['{identity_insert_on}', '\n {with_temp_date_bounds}\n MERGE {target_table_name} f\n USING (SELECT {patch_cols_str} FROM {patch_table_name}) p\n ON\n {and_subquery_f}\n AND\n {date_bounds_subquery}{when_matched_update_sets_subquery_none}\n WHEN NOT MATCHED THEN\n INSERT ({patch_cols_str})\n VALUES ({patch_cols_prefixed_str});\n ', '{identity_insert_off}'], 'oracle': '\n MERGE INTO {target_table_name} f\n USING (SELECT {patch_cols_str} FROM {patch_table_name}) p\n ON (\n {and_subquery_f}\n AND\n {date_bounds_subquery}\n )\n WHEN MATCHED THEN\n UPDATE\n {sets_subquery_none}\n ', 'oracle-upsert': '\n MERGE INTO {target_table_name} f\n USING (SELECT {patch_cols_str} FROM {patch_table_name}) p\n ON (\n {and_subquery_f}\n AND\n {date_bounds_subquery}\n ){when_matched_update_sets_subquery_none}\n WHEN NOT MATCHED THEN\n INSERT ({patch_cols_str})\n VALUES ({patch_cols_prefixed_str})\n ', 'sqlite-upsert': '\n INSERT INTO {target_table_name} ({patch_cols_str})\n SELECT {patch_cols_str}\n FROM {patch_table_name}\n WHERE true\n ON CONFLICT ({join_cols_str}) DO {update_or_nothing} {sets_subquery_none_excluded}\n ', 'sqlite_delete_insert': ['\n DELETE FROM {target_table_name} AS f\n WHERE ROWID IN (\n SELECT t.ROWID\n FROM {target_table_name} AS t\n INNER JOIN (SELECT * FROM {patch_table_name}) AS p\n ON {and_subquery_t}\n );\n ', '\n INSERT INTO {target_table_name} AS f\n SELECT {patch_cols_str} FROM {patch_table_name} AS p\n ']}
columns_types_queries = {'default': "\n SELECT\n table_catalog AS database,\n table_schema AS schema,\n table_name AS table,\n column_name AS column,\n data_type AS type,\n numeric_precision,\n numeric_scale\n FROM information_schema.columns\n WHERE table_name IN ('{table}', '{table_trunc}')\n ", 'sqlite': '\n SELECT\n \'\' "database",\n \'\' "schema",\n m.name "table",\n p.name "column",\n p.type "type"\n FROM sqlite_master m\n LEFT OUTER JOIN pragma_table_info(m.name) p\n ON m.name <> p.name\n WHERE m.type = \'table\'\n AND m.name IN (\'{table}\', \'{table_trunc}\')\n ', 'mssql': "\n SELECT\n TABLE_CATALOG AS [database],\n TABLE_SCHEMA AS [schema],\n TABLE_NAME AS [table],\n COLUMN_NAME AS [column],\n DATA_TYPE AS [type],\n NUMERIC_PRECISION AS [numeric_precision],\n NUMERIC_SCALE AS [numeric_scale]\n FROM {db_prefix}INFORMATION_SCHEMA.COLUMNS\n WHERE TABLE_NAME IN (\n '{table}',\n '{table_trunc}'\n )\n\n ", 'mysql': "\n SELECT\n TABLE_SCHEMA `database`,\n TABLE_SCHEMA `schema`,\n TABLE_NAME `table`,\n COLUMN_NAME `column`,\n DATA_TYPE `type`,\n NUMERIC_PRECISION `numeric_precision`,\n NUMERIC_SCALE `numeric_scale`\n FROM INFORMATION_SCHEMA.COLUMNS\n WHERE TABLE_NAME IN ('{table}', '{table_trunc}')\n ", 'mariadb': "\n SELECT\n TABLE_SCHEMA `database`,\n TABLE_SCHEMA `schema`,\n TABLE_NAME `table`,\n COLUMN_NAME `column`,\n DATA_TYPE `type`,\n NUMERIC_PRECISION `numeric_precision`,\n NUMERIC_SCALE `numeric_scale`\n FROM INFORMATION_SCHEMA.COLUMNS\n WHERE TABLE_NAME IN ('{table}', '{table_trunc}')\n ", 'oracle': '\n SELECT\n NULL AS "database",\n NULL AS "schema",\n TABLE_NAME AS "table",\n COLUMN_NAME AS "column",\n DATA_TYPE AS "type",\n DATA_PRECISION AS "numeric_precision",\n DATA_SCALE AS "numeric_scale"\n FROM all_tab_columns\n WHERE TABLE_NAME IN (\n \'{table}\',\n \'{table_trunc}\',\n \'{table_lower}\',\n \'{table_lower_trunc}\',\n \'{table_upper}\',\n \'{table_upper_trunc}\'\n )\n '}
### Size queries for distributed-table implementations (TimescaleDB hypertables,
### Citus distributed tables); `{table_name}` is the escaped table name.
### NOTE(review): presumably these also serve to detect whether a table is a
### hypertable/distributed table (they error otherwise) — verify at call sites.
hypertable_queries = {'timescaledb': "SELECT hypertable_size('{table_name}')", 'citus': "SELECT citus_table_size('{table_name}')"}
### Flavor-keyed queries returning index metadata (database, schema, table,
### column, index name, index type — and `clustered` for MSSQL) for a table;
### `{table}`, `{table_trunc}`, `{schema}` (and upper variants for Oracle)
### are interpolated before execution.
columns_indices_queries = {'default': '\n    SELECT\n        current_database() AS "database",\n        n.nspname AS "schema",\n        t.relname AS "table",\n        c.column_name AS "column",\n        i.relname AS "index",\n        CASE WHEN con.contype = \'p\' THEN \'PRIMARY KEY\' ELSE \'INDEX\' END AS "index_type"\n    FROM pg_class t\n    INNER JOIN pg_index AS ix\n        ON t.oid = ix.indrelid\n    INNER JOIN pg_class AS i\n        ON i.oid = ix.indexrelid\n    INNER JOIN pg_namespace AS n\n        ON n.oid = t.relnamespace\n    INNER JOIN pg_attribute AS a\n        ON a.attnum = ANY(ix.indkey)\n        AND a.attrelid = t.oid\n    INNER JOIN information_schema.columns AS c\n        ON c.column_name = a.attname\n        AND c.table_name = t.relname\n        AND c.table_schema = n.nspname\n    LEFT JOIN pg_constraint AS con\n        ON con.conindid = i.oid\n        AND con.contype = \'p\'\n    WHERE\n        t.relname IN (\'{table}\', \'{table_trunc}\')\n        AND n.nspname = \'{schema}\'\n    ', 'sqlite': '\n    WITH indexed_columns AS (\n        SELECT\n            \'{table}\' AS table_name,\n            pi.name AS column_name,\n            i.name AS index_name,\n            \'INDEX\' AS index_type\n        FROM\n            sqlite_master AS i,\n            pragma_index_info(i.name) AS pi\n        WHERE\n            i.type = \'index\'\n            AND i.tbl_name = \'{table}\'\n    ),\n    primary_key_columns AS (\n        SELECT\n            \'{table}\' AS table_name,\n            ti.name AS column_name,\n            \'PRIMARY_KEY\' AS index_name,\n            \'PRIMARY KEY\' AS index_type\n        FROM\n            pragma_table_info(\'{table}\') AS ti\n        WHERE\n            ti.pk > 0\n    )\n    SELECT\n        NULL AS "database",\n        NULL AS "schema",\n        "table_name" AS "table",\n        "column_name" AS "column",\n        "index_name" AS "index",\n        "index_type"\n    FROM indexed_columns\n    UNION ALL\n    SELECT\n        NULL AS "database",\n        NULL AS "schema",\n        table_name AS "table",\n        column_name AS "column",\n        index_name AS "index",\n        index_type\n    FROM primary_key_columns\n    ', 'mssql': "\n    SELECT\n        NULL AS [database],\n        s.name AS [schema],\n        t.name AS [table],\n        c.name AS [column],\n        i.name AS [index],\n        CASE\n            WHEN kc.type = 'PK' THEN 'PRIMARY KEY'\n            ELSE 'INDEX'\n        END AS [index_type],\n        CASE\n            WHEN i.type = 1 THEN CAST(1 AS BIT)\n            ELSE CAST(0 AS BIT)\n        END AS [clustered]\n    FROM\n        sys.schemas s\n    INNER JOIN sys.tables t\n        ON s.schema_id = t.schema_id\n    INNER JOIN sys.indexes i\n        ON t.object_id = i.object_id\n    INNER JOIN sys.index_columns ic\n        ON i.object_id = ic.object_id\n        AND i.index_id = ic.index_id\n    INNER JOIN sys.columns c\n        ON ic.object_id = c.object_id\n        AND ic.column_id = c.column_id\n    LEFT JOIN sys.key_constraints kc\n        ON kc.parent_object_id = i.object_id\n        AND kc.type = 'PK'\n        AND kc.name = i.name\n    WHERE\n        t.name IN ('{table}', '{table_trunc}')\n        AND s.name = '{schema}'\n        AND i.type IN (1, 2)\n    ", 'oracle': '\n    SELECT\n        NULL AS "database",\n        ic.table_owner AS "schema",\n        ic.table_name AS "table",\n        ic.column_name AS "column",\n        i.index_name AS "index",\n        CASE\n            WHEN c.constraint_type = \'P\' THEN \'PRIMARY KEY\'\n            WHEN i.uniqueness = \'UNIQUE\' THEN \'UNIQUE INDEX\'\n            ELSE \'INDEX\'\n        END AS index_type\n    FROM\n        all_ind_columns ic\n    INNER JOIN all_indexes i\n        ON ic.index_name = i.index_name\n        AND ic.table_owner = i.owner\n    LEFT JOIN all_constraints c\n        ON i.index_name = c.constraint_name\n        AND i.table_owner = c.owner\n        AND c.constraint_type = \'P\'\n    WHERE ic.table_name IN (\n        \'{table}\',\n        \'{table_trunc}\',\n        \'{table_upper}\',\n        \'{table_upper_trunc}\'\n    )\n    ', 'mysql': "\n    SELECT\n        TABLE_SCHEMA AS `database`,\n        TABLE_SCHEMA AS `schema`,\n        TABLE_NAME AS `table`,\n        COLUMN_NAME AS `column`,\n        INDEX_NAME AS `index`,\n        CASE\n            WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'\n            ELSE 'INDEX'\n        END AS `index_type`\n    FROM\n        information_schema.STATISTICS\n    WHERE\n        TABLE_NAME IN ('{table}', '{table_trunc}')\n    ", 'mariadb': "\n    SELECT\n        TABLE_SCHEMA AS `database`,\n        TABLE_SCHEMA AS `schema`,\n        TABLE_NAME AS `table`,\n        COLUMN_NAME AS `column`,\n        INDEX_NAME AS `index`,\n        CASE\n            WHEN NON_UNIQUE = 0 THEN 'PRIMARY KEY'\n            ELSE 'INDEX'\n        END AS `index_type`\n    FROM\n        information_schema.STATISTICS\n    WHERE\n        TABLE_NAME IN ('{table}', '{table_trunc}')\n    "}
reset_autoincrement_queries: Dict[str, Union[str, List[str]]] = {'default': "\n SELECT SETVAL(pg_get_serial_sequence('{table}', '{column}'), {val})\n FROM {table_name}\n ", 'mssql': "\n DBCC CHECKIDENT ('{table}', RESEED, {val})\n ", 'mysql': '\n ALTER TABLE {table_name} AUTO_INCREMENT = {val}\n ', 'mariadb': '\n ALTER TABLE {table_name} AUTO_INCREMENT = {val}\n ', 'sqlite': "\n UPDATE sqlite_sequence\n SET seq = {val}\n WHERE name = '{table}'\n ", 'oracle': 'ALTER TABLE {table_name} MODIFY {column_name} GENERATED BY DEFAULT ON NULL AS IDENTITY (START WITH {val_plus_1})'}
### Identifier quote characters (left, right) per flavor.
table_wrappers = {'default': ('"', '"'), 'timescaledb': ('"', '"'), 'citus': ('"', '"'), 'duckdb': ('"', '"'), 'postgresql': ('"', '"'), 'sqlite': ('"', '"'), 'mysql': ('`', '`'), 'mariadb': ('`', '`'), 'mssql': ('[', ']'), 'cockroachdb': ('"', '"'), 'oracle': ('"', '"')}
### Maximum identifier lengths per flavor (see `truncate_item_name()`).
max_name_lens = {'default': 64, 'mssql': 128, 'oracle': 30, 'postgresql': 64, 'timescaledb': 64, 'citus': 64, 'cockroachdb': 64, 'sqlite': 1024, 'mysql': 64, 'mariadb': 64}
### Flavors which support JSON column types.
json_flavors = {'cockroachdb', 'timescaledb', 'citus', 'postgresql'}
### Flavors which do not support schemas.
NO_SCHEMA_FLAVORS = {'oracle', 'mysql', 'sqlite', 'duckdb', 'mariadb'}
### Default schema name to use when none is configured, per flavor.
DEFAULT_SCHEMA_FLAVORS = {'postgresql': 'public', 'timescaledb': 'public', 'citus': 'public', 'cockroachdb': 'public', 'mysql': 'mysql', 'mariadb': 'mysql', 'mssql': 'dbo'}
### Flavors for which the `NULLS FIRST` ordering clause must be omitted.
OMIT_NULLSFIRST_FLAVORS = {'mysql', 'mssql', 'mariadb'}
### Flavors limited to one column change per `ALTER TABLE` statement.
SINGLE_ALTER_TABLE_FLAVORS = {'oracle', 'sqlite', 'duckdb', 'mssql'}
### Flavors without common table expression (`WITH ...`) support.
NO_CTE_FLAVORS = {'mysql', 'mariadb'}
### Flavors without `SELECT ... INTO` support.
NO_SELECT_INTO_FLAVORS = {'oracle', 'mysql', 'sqlite', 'duckdb', 'mariadb'}
def clean(substring: str) -> str:
    """
    Ensure a substring is clean enough to be inserted into a SQL query.
    Raises an exception when banned words are used.

    NOTE(review): annotated `-> str` but implicitly returns `None`;
    callers appear to use this for its side effect (raising) only — confirm.
    """
    from meerschaum.utils.warnings import error
    lowered = str(substring).lower()
    for banned in (';', '--', 'drop '):
        if banned in lowered:
            error(f"Invalid string: '{substring}'")

Ensure a substring is clean enough to be inserted into a SQL query. Raises an exception when banned words are used.

def dateadd_str(
    flavor: str = 'postgresql',
    datepart: str = 'day',
    number: Union[int, float] = 0,
    begin: Union[str, datetime, int] = 'now',
    db_type: Optional[str] = None,
) -> str:
    """
    Generate a `DATEADD` clause depending on database flavor.

    Parameters
    ----------
    flavor: str, default `'postgresql'`
        SQL database flavor, e.g. `'postgresql'`, `'sqlite'`.

        Currently supported flavors:

        - `'postgresql'`
        - `'timescaledb'`
        - `'citus'`
        - `'cockroachdb'`
        - `'duckdb'`
        - `'mssql'`
        - `'mysql'`
        - `'mariadb'`
        - `'sqlite'`
        - `'oracle'`

    datepart: str, default `'day'`
        Which part of the date to modify. Supported values:

        - `'year'`
        - `'month'`
        - `'day'`
        - `'hour'`
        - `'minute'`
        - `'second'`

    number: Union[int, float], default `0`
        How many units to add to the date part.

    begin: Union[str, datetime, int], default `'now'`
        Base datetime to which to add dateparts.
        Integers are returned as plain arithmetic strings (epoch-style columns).

    db_type: Optional[str], default None
        If provided, cast the datetime string as the type.
        Otherwise, infer this from the input datetime value.

    Returns
    -------
    The appropriate `DATEADD` string for the corresponding database flavor.

    Examples
    --------
    >>> dateadd_str(
    ...     flavor='mssql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
    >>> dateadd_str(
    ...     flavor='postgresql',
    ...     begin=datetime(2022, 1, 1, 0, 0),
    ...     number=1,
    ... )
    "CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"

    """
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.dtypes.sql import get_db_type_from_pd_type, get_pd_type_from_db_type
    dateutil_parser = attempt_import('dateutil.parser')

    ### NOTE: matching on the type name (rather than `isinstance`) also catches
    ### NumPy integer types without importing NumPy.
    if 'int' in str(type(begin)).lower():
        num_str = str(begin)
        if number is not None and number != 0:
            num_str += (
                f' + {number}'
                if number > 0
                else f" - {number * -1}"
            )
        return num_str
    if not begin:
        return ''

    begin_time = None
    ### Sanity check: make sure `begin` is a valid datetime before we inject anything.
    if not isinstance(begin, datetime):
        try:
            begin_time = dateutil_parser.parse(begin)
        except Exception:
            begin_time = None
    else:
        begin_time = begin

    ### Unable to parse into a datetime.
    if begin_time is None:
        ### Throw an error if banned symbols are included in the `begin` string.
        clean(str(begin))
    ### If begin is a valid datetime, wrap it in quotes.
    else:
        if isinstance(begin, datetime) and begin.tzinfo is not None:
            begin = begin.astimezone(timezone.utc)
        begin = (
            f"'{begin.replace(tzinfo=None)}'"
            if isinstance(begin, datetime) and flavor in TIMEZONE_NAIVE_FLAVORS
            else f"'{begin}'"
        )

    dt_is_utc = (
        begin_time.tzinfo is not None
        if begin_time is not None
        else ('+' in str(begin) or '-' in str(begin).split(':', maxsplit=1)[-1])
    )
    if db_type:
        db_type_is_utc = 'utc' in get_pd_type_from_db_type(db_type).lower()
        dt_is_utc = dt_is_utc or db_type_is_utc
    db_type = db_type or get_db_type_from_pd_type(
        ('datetime64[ns, UTC]' if dt_is_utc else 'datetime64[ns]'),
        flavor=flavor,
    )

    da = ""
    if flavor in ('postgresql', 'timescaledb', 'cockroachdb', 'citus'):
        begin = (
            f"CAST({begin} AS {db_type})" if begin != 'now'
            else f"CAST(NOW() AT TIME ZONE 'utc' AS {db_type})"
        )
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor == 'duckdb':
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'NOW()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = begin + (f" + INTERVAL '{number} {datepart}'" if number != 0 else '')

    elif flavor in ('mssql',):
        ### DATETIME2 casts only accept up to 7 fractional digits; trim the last 3
        ### of the 6-digit microseconds (plus closing quote) for naive datetimes.
        if begin_time and begin_time.microsecond != 0 and not dt_is_utc:
            begin = begin[:-4] + "'"
        begin = f"CAST({begin} AS {db_type})" if begin != 'now' else 'GETUTCDATE()'
        if dt_is_utc:
            begin += " AT TIME ZONE 'UTC'"
        da = f"DATEADD({datepart}, {number}, {begin})" if number != 0 else begin

    elif flavor in ('mysql', 'mariadb'):
        begin = (
            f"CAST({begin} AS DATETIME(6))"
            if begin != 'now'
            else 'UTC_TIMESTAMP(6)'
        )
        da = (f"DATE_ADD({begin}, INTERVAL {number} {datepart})" if number != 0 else begin)

    elif flavor == 'sqlite':
        da = f"datetime({begin}, '{number} {datepart}')"

    elif flavor == 'oracle':
        if begin == 'now':
            ### BUGFIX: previously formatted with '%Y:%m:%d %M:%S.%f' (colons in the
            ### date, hours dropped), which did not match `dt_format` below, and left
            ### `begin_time` unset so the literal was emitted unquoted without
            ### TO_TIMESTAMP(). Keep `begin_time` assigned so both paths below apply.
            begin_time = datetime.now(timezone.utc).replace(tzinfo=None)
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        elif begin_time:
            begin = str(begin_time.strftime(r'%Y-%m-%d %H:%M:%S.%f'))
        dt_format = 'YYYY-MM-DD HH24:MI:SS.FF'
        _begin = f"'{begin}'" if begin_time else begin
        da = (
            (f"TO_TIMESTAMP({_begin}, '{dt_format}')" if begin_time else _begin)
            + (f" + INTERVAL '{number}' {datepart}" if number != 0 else "")
        )
    return da

Generate a DATEADD clause depending on database flavor.

Parameters
  • flavor (str, default 'postgresql'): SQL database flavor, e.g. 'postgresql', 'sqlite'.

    Currently supported flavors:

    • 'postgresql'
    • 'timescaledb'
    • 'citus'
    • 'cockroachdb'
    • 'duckdb'
    • 'mssql'
    • 'mysql'
    • 'mariadb'
    • 'sqlite'
    • 'oracle'
  • datepart (str, default 'day'): Which part of the date to modify. Supported values:

    • 'year'
    • 'month'
    • 'day'
    • 'hour'
    • 'minute'
    • 'second'
  • number (Union[int, float], default 0): How many units to add to the date part.
  • begin (Union[str, datetime], default 'now'): Base datetime to which to add dateparts.
  • db_type (Optional[str], default None): If provided, cast the datetime string as the type. Otherwise, infer this from the input datetime value.
Returns
  • The appropriate DATEADD string for the corresponding database flavor.
Examples
>>> dateadd_str(
...     flavor='mssql',
...     begin=datetime(2022, 1, 1, 0, 0),
...     number=1,
... )
"DATEADD(day, 1, CAST('2022-01-01 00:00:00' AS DATETIME2))"
>>> dateadd_str(
...     flavor='postgresql',
...     begin=datetime(2022, 1, 1, 0, 0),
...     number=1,
... )
"CAST('2022-01-01 00:00:00' AS TIMESTAMP) + INTERVAL '1 day'"
def test_connection(
    self,
    **kw: Any
) -> Union[bool, None]:
    """
    Test if a successful connection to the database may be made.

    Parameters
    ----------
    **kw:
        The keyword arguments are passed to `meerschaum.connectors.poll.retry_connect`.

    Returns
    -------
    `True` if a connection is made, otherwise `False` or `None` in case of failure.

    """
    import warnings
    from meerschaum.connectors.poll import retry_connect
    ### Single attempt, no waiting; caller-provided kwargs override the defaults.
    retry_kw = {
        'max_retries': 1,
        'retry_wait': 0,
        'warn': False,
        'connector': self,
        **kw,
    }
    with warnings.catch_warnings():
        ### Suppress the "Could not ..." warnings emitted by failed attempts.
        warnings.filterwarnings('ignore', 'Could not')
        try:
            return retry_connect(**retry_kw)
        except Exception:
            return False

Test if a successful connection to the database may be made.

Parameters
Returns
  • True if a connection is made, otherwise False or None in case of failure.
def get_distinct_col_count(
    col: str,
    query: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    debug: bool = False
) -> Optional[int]:
    """
    Returns the number of distinct items in a column of a SQL query.

    Parameters
    ----------
    col: str
        The column in the query to count.

    query: str
        The SQL query to count from.

    connector: Optional[mrsm.connectors.sql.SQLConnector], default None
        The SQLConnector to execute the query.
        Defaults to the default SQL connector.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    An `int` of the number of distinct values in the column,
    or `None` if the query fails.

    """
    if connector is None:
        connector = mrsm.get_connector('sql')

    _col_name = sql_item_name(col, connector.flavor, None)

    ### MySQL and MariaDB lack reliable CTE support, so use a nested subquery there.
    if connector.flavor in ('mysql', 'mariadb'):
        _meta_query = f"""
        SELECT COUNT(*)
        FROM (
            SELECT DISTINCT {_col_name}
            FROM ({query}) AS src
        ) AS dist"""
    else:
        _meta_query = f"""
        WITH src AS ( {query} ),
        dist AS ( SELECT DISTINCT {_col_name} FROM src )
        SELECT COUNT(*) FROM dist"""

    result = connector.value(_meta_query, debug=debug)
    try:
        return int(result)
    except Exception:
        return None

Returns the number of distinct items in a column of a SQL query.

Parameters
  • col (str:): The column in the query to count.
  • query (str:): The SQL query to count from.
  • connector (Optional[mrsm.connectors.sql.SQLConnector], default None:): The SQLConnector to execute the query.
  • debug (bool, default False:): Verbosity toggle.
Returns
  • An int of the number of distinct values in the column or None if the query fails.
def sql_item_name(item: str, flavor: str, schema: Optional[str] = None) -> str:
    """
    Parse SQL items depending on the flavor.

    Parameters
    ----------
    item: str
        The database item (table, view, etc.) in need of quotes.

    flavor: str
        The database flavor (`'postgresql'`, `'mssql'`, `'sqlite'`, etc.).

    schema: Optional[str], default None
        If provided, prefix the table name with the schema.

    Returns
    -------
    A `str` which contains the input `item` wrapped in the corresponding escape characters.

    Examples
    --------
    >>> sql_item_name('table', 'sqlite')
    '"table"'
    >>> sql_item_name('table', 'mssql')
    "[table]"
    >>> sql_item_name('table', 'postgresql', schema='abc')
    '"abc"."table"'

    """
    truncated_item = truncate_item_name(str(item), flavor)
    if flavor == 'oracle':
        truncated_item = pg_capital(truncated_item, quote_capitals=True)
        ### NOTE: System-reserved words must be quoted.
        reserved_words = (
            'float', 'varchar', 'nvarchar', 'clob',
            'boolean', 'integer', 'table', 'row',
        )
        wrappers = ('"', '"') if truncated_item.lower() in reserved_words else ('', '')
    else:
        wrappers = table_wrappers.get(flavor, table_wrappers['default'])

    ### NOTE: SQLite does not support schemas,
    ### and MSSQL temp tables (names starting with '#') take none.
    if flavor == 'sqlite' or (flavor == 'mssql' and str(item).startswith('#')):
        schema = None

    left, right = wrappers
    schema_prefix = f"{left}{schema}{right}." if schema is not None else ''
    return f"{schema_prefix}{left}{truncated_item}{right}"

Parse SQL items depending on the flavor.

Parameters
  • item (str): The database item (table, view, etc.) in need of quotes.
  • flavor (str): The database flavor ('postgresql', 'mssql', 'sqlite', etc.).
  • schema (Optional[str], default None): If provided, prefix the table name with the schema.
Returns
  • A str which contains the input item wrapped in the corresponding escape characters.
Examples
>>> sql_item_name('table', 'sqlite')
'"table"'
>>> sql_item_name('table', 'mssql')
"[table]"
>>> sql_item_name('table', 'postgresql', schema='abc')
'"abc"."table"'
def pg_capital(s: str, quote_capitals: bool = True) -> str:
    """
    If string contains a capital letter (or another character requiring escaping),
    wrap it in double quotes.

    Parameters
    ----------
    s: str
        The string to be escaped.

    quote_capitals: bool, default True
        If `False`, do not quote strings that contain only a mix of capital
        and lower-case letters.

    Returns
    -------
    The input string wrapped in quotes only if it needs them.

    Examples
    --------
    >>> pg_capital("My Table")
    '"My Table"'
    >>> pg_capital('my_table')
    'my_table'

    """
    ### Already-quoted strings pass through untouched.
    if s.startswith('"') and s.endswith('"'):
        return s

    ### Strip interior double quotes before deciding.
    candidate = s.replace('"', '')

    def _forces_quotes(ch: str) -> bool:
        ### Underscores are allowed anywhere but the leading position.
        if ch == '_':
            return False
        return (not ch.isalnum()) or (quote_capitals and ch.isupper())

    if candidate.startswith('_') or any(_forces_quotes(ch) for ch in candidate):
        return f'"{candidate}"'
    return candidate

If string contains a capital letter, wrap it in double quotes.

Parameters
  • s (str): The string to be escaped.
  • quote_capitals (bool, default True): If False, do not quote strings that contain only a mix of capital and lower-case letters.
Returns
  • The input string wrapped in quotes only if it needs them.
Examples
>>> pg_capital("My Table")
'"My Table"'
>>> pg_capital('my_table')
'my_table'
def oracle_capital(s: str) -> str:
    """
    Capitalize the string of an item on an Oracle database.

    Currently a no-op kept for backward compatibility: the input is returned unchanged.
    """
    return s

Capitalize the string of an item on an Oracle database.

def truncate_item_name(item: str, flavor: str) -> str:
    """
    Truncate item names to stay within the database flavor's character limit.

    Parameters
    ----------
    item: str
        The database item being referenced. This string is the "canonical" name internally.

    flavor: str
        The flavor of the database on which `item` resides.

    Returns
    -------
    The truncated string.
    """
    from meerschaum.utils.misc import truncate_string_sections
    max_len = max_name_lens.get(flavor, max_name_lens['default'])
    return truncate_string_sections(item, max_len=max_len)

Truncate item names to stay within the database flavor's character limit.

Parameters
  • item (str): The database item being referenced. This string is the "canonical" name internally.
  • flavor (str): The flavor of the database on which item resides.
Returns
  • The truncated string.
def build_where(
    params: Dict[str, Any],
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    with_where: bool = True,
) -> str:
    """
    Build the `WHERE` clause based on the input criteria.

    Parameters
    ----------
    params: Dict[str, Any]
        The keywords dictionary to convert into a WHERE clause.
        If a value is a string which begins with an underscore, negate that value
        (e.g. `!=` instead of `=` or `NOT IN` instead of `IN`).
        A value of `_None` will be interpreted as `IS NOT NULL`.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The Meerschaum SQLConnector that will be executing the query.
        The connector is used to extract the SQL dialect.

    with_where: bool, default True
        If `True`, include the leading `'WHERE'` string.

    Returns
    -------
    A `str` of the `WHERE` clause from the input `params` dictionary for the connector's flavor.

    Examples
    --------
    ```
    >>> print(build_where({'foo': [1, 2, 3]}))
    
    WHERE
        "foo" IN ('1', '2', '3')
    ```
    """
    import json
    from meerschaum.config.static import STATIC_CONFIG
    from meerschaum.utils.warnings import warn
    from meerschaum.utils.dtypes import value_is_null, none_if_null
    negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
    try:
        params_json = json.dumps(params)
    except Exception:
        params_json = str(params)
    ### Refuse to build anything if the serialized params look like an injection attempt.
    bad_words = ['drop ', '--', ';']
    for word in bad_words:
        if word in params_json.lower():
            warn("Aborting build_where() due to possible SQL injection.")
            return ''

    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')
    where = ""
    leading_and = "\n    AND "
    for key, value in params.items():
        _key = sql_item_name(key, connector.flavor, None)
        ### search across a list (i.e. IN syntax)
        if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
            includes = [
                none_if_null(item)
                for item in value
                if not str(item).startswith(negation_prefix)
            ]
            null_includes = [item for item in includes if item is None]
            not_null_includes = [item for item in includes if item is not None]
            excludes = [
                none_if_null(str(item)[len(negation_prefix):])
                for item in value
                if str(item).startswith(negation_prefix)
            ]
            null_excludes = [item for item in excludes if item is None]
            not_null_excludes = [item for item in excludes if item is not None]

            if includes:
                where += f"{leading_and}("
            if not_null_includes:
                where += f"{_key} IN ("
                for item in not_null_includes:
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                ### Trim the trailing ", " left by the loop above.
                where = where[:-2] + ")"
            if null_includes:
                where += ("\n    OR " if not_null_includes else "") + f"{_key} IS NULL"
            if includes:
                where += ")"

            if excludes:
                where += f"{leading_and}("
            if not_null_excludes:
                where += f"{_key} NOT IN ("
                for item in not_null_excludes:
                    quoted_item = str(item).replace("'", "''")
                    where += f"'{quoted_item}', "
                where = where[:-2] + ")"
            if null_excludes:
                where += ("\n    AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
            if excludes:
                where += ")"

            continue

        ### search a dictionary (serialized and compared as text)
        elif isinstance(value, dict):
            where += (f"{leading_and}CAST({_key} AS TEXT) = '" + json.dumps(value) + "'")
            continue

        eq_sign = '='
        is_null = 'IS NULL'
        if value_is_null(str(value).lstrip(negation_prefix)):
            value = (
                (negation_prefix + 'None')
                if str(value).startswith(negation_prefix)
                else None
            )
        if str(value).startswith(negation_prefix):
            value = str(value)[len(negation_prefix):]
            eq_sign = '!='
            if value_is_null(value):
                value = None
                is_null = 'IS NOT NULL'
        quoted_value = str(value).replace("'", "''")
        where += (
            f"{leading_and}{_key} "
            + (is_null if value is None else f"{eq_sign} '{quoted_value}'")
        )

    if len(where) > 1:
        where = ("\nWHERE\n    " if with_where else '') + where[len(leading_and):]
    return where

Build the WHERE clause based on the input criteria.

Parameters
  • params (Dict[str, Any]:): The keywords dictionary to convert into a WHERE clause. If a value is a string which begins with an underscore, negate that value (e.g. != instead of = or NOT IN instead of IN). A value of _None will be interpreted as IS NOT NULL.
  • connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The Meerschaum SQLConnector that will be executing the query. The connector is used to extract the SQL dialect.
  • with_where (bool, default True:): If True, include the leading 'WHERE' string.
Returns
  • A str of the WHERE clause from the input params dictionary for the connector's flavor.
Examples
>>> print(build_where({'foo': [1, 2, 3]}))

WHERE
    "foo" IN ('1', '2', '3')
def table_exists(
    table: str,
    connector: mrsm.connectors.sql.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> bool:
    """Check if a table exists.

    Parameters
    ----------
    table: str
        The name of the table in question.

    connector: mrsm.connectors.sql.SQLConnector
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Optionally specify the table schema.
        Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `bool` indicating whether or not the table exists on the database.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    inspector = sqlalchemy.inspect(connector.engine)
    truncated_table_name = truncate_item_name(str(table), connector.flavor)
    return inspector.has_table(
        truncated_table_name,
        schema=(schema or connector.schema),
    )

Check if a table exists.

Parameters
  • table (str:): The name of the table in question.
  • connector (mrsm.connectors.sql.SQLConnector): The connector to the database which holds the table.
  • schema (Optional[str], default None): Optionally specify the table schema. Defaults to connector.schema.
  • debug (bool, default False :): Verbosity toggle.
Returns
  • A bool indicating whether or not the table exists on the database.
def get_sqlalchemy_table(
    table: str,
    connector: Optional[mrsm.connectors.sql.SQLConnector] = None,
    schema: Optional[str] = None,
    refresh: bool = False,
    debug: bool = False,
) -> Union['sqlalchemy.Table', None]:
    """
    Construct a SQLAlchemy table from its name.

    Parameters
    ----------
    table: str
        The name of the table on the database. Does not need to be escaped.

    connector: Optional[meerschaum.connectors.sql.SQLConnector], default None
        The connector to the database which holds the table.

    schema: Optional[str], default None
        Specify on which schema the table resides.
        Defaults to the schema set in `connector`.

    refresh: bool, default False
        If `True`, rebuild the cached table object.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `sqlalchemy.Table` object for the table, or `None` for DuckDB
    or if the table does not exist.

    """
    if connector is None:
        from meerschaum import get_connector
        connector = get_connector('sql')

    ### DuckDB tables are not reflected via SQLAlchemy.
    if connector.flavor == 'duckdb':
        return None

    from meerschaum.connectors.sql.tables import get_tables
    from meerschaum.utils.packages import attempt_import
    from meerschaum.utils.warnings import warn
    if refresh:
        connector.metadata.clear()
    tables = get_tables(mrsm_instance=connector, debug=debug, create=False)
    sqlalchemy = attempt_import('sqlalchemy', lazy=False)
    truncated_table_name = truncate_item_name(str(table), connector.flavor)

    ### Serve the cached object unless a rebuild was requested.
    if not refresh and truncated_table_name in tables:
        return tables[truncated_table_name]

    reflect_kwargs = {'autoload_with': connector.engine}
    if schema:
        reflect_kwargs['schema'] = schema

    try:
        tables[truncated_table_name] = sqlalchemy.Table(
            truncated_table_name,
            connector.metadata,
            **reflect_kwargs
        )
    except sqlalchemy.exc.NoSuchTableError:
        warn(f"Table '{truncated_table_name}' does not exist in '{connector}'.")
        return None
    return tables[truncated_table_name]

Construct a SQLAlchemy table from its name.

Parameters
  • table (str): The name of the table on the database. Does not need to be escaped.
  • connector (Optional[meerschaum.connectors.sql.SQLConnector], default None:): The connector to the database which holds the table.
  • schema (Optional[str], default None): Specify on which schema the table resides. Defaults to the schema set in connector.
  • refresh (bool, default False): If True, rebuild the cached table object.
  • debug (bool, default False:): Verbosity toggle.
Returns
  • A sqlalchemy.Table object for the table.
def get_table_cols_types( table: str, connectable: "Union['mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine']", flavor: Optional[str] = None, schema: Optional[str] = None, database: Optional[str] = None, debug: bool = False) -> Dict[str, str]:
def get_table_cols_types(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, str]:
    """
    Return a dictionary mapping a table's columns to data types.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
        Use this function if you are confident the table name is unique or if you have
        an explicit schema.
        To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to data types.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Some flavors match on case-folded or truncated variants of the table name,
    ### so prepare all of them for the flavor-specific query template.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temporary tables (names starting with '#') live in `tempdb`.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_types_query = sqlalchemy.text(
        textwrap.dedent(columns_types_queries.get(
            flavor,
            columns_types_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
        )).lstrip().rstrip()
    )

    ### Expected column order of the result rows, shared by all flavor queries.
    cols = ['database', 'schema', 'table', 'column', 'type', 'numeric_precision', 'numeric_scale']
    result_cols_ix = dict(enumerate(cols))

    ### Only `SQLConnector.execute()` accepts a `debug` keyword;
    ### raw sessions/engines do not, so print the query here instead.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_types_query)

    try:
        result_rows = (
            [
                row
                for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]

        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        ### Oracle reports unquoted identifiers in upper case; fold simple
        ### all-caps alphabetic names back to lower case for consistency.
        return {
            (
                doc['column']
                if flavor != 'oracle' else (
                    (
                        doc['column'].lower()
                        if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                        else doc['column']
                    )
                )
            ): doc['type'].upper() + (
                f'({precision},{scale})'
                if (
                    (precision := doc.get('numeric_precision', None))
                     and
                    (scale := doc.get('numeric_scale', None))
                )
                else ''
            )
            for doc in cols_types_docs_filtered
        }
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}

Return a dictionary mapping a table's columns to data types. This is useful for inspecting tables created during a not-yet-committed session.

NOTE: This may return incorrect columns if the schema is not explicitly stated. Use this function if you are confident the table name is unique or if you have an explicit schema. To use the configured schema, get the columns from get_sqlalchemy_table() instead.

Parameters
  • table (str): The name of the table (unquoted).
  • connectable (Union['mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine']): The connection object used to fetch the columns and types.
  • flavor (Optional[str], default None): The database dialect flavor to use for the query. If omitted, default to connectable.flavor.
  • schema (Optional[str], default None): If provided, restrict the query to this schema.
  • database (Optional[str], default None): If provided, restrict the query to this database.
Returns
  • A dictionary mapping column names to data types.
def get_table_cols_indices( table: str, connectable: "Union['mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine']", flavor: Optional[str] = None, schema: Optional[str] = None, database: Optional[str] = None, debug: bool = False) -> Dict[str, List[str]]:
def get_table_cols_indices(
    table: str,
    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ],
    flavor: Optional[str] = None,
    schema: Optional[str] = None,
    database: Optional[str] = None,
    debug: bool = False,
) -> Dict[str, List[str]]:
    """
    Return a dictionary mapping a table's columns to lists of indices.
    This is useful for inspecting tables created during a not-yet-committed session.

    NOTE: This may return incorrect columns if the schema is not explicitly stated.
        Use this function if you are confident the table name is unique or if you have
        an explicit schema.
        To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.

    Parameters
    ----------
    table: str
        The name of the table (unquoted).

    connectable: Union[
        'mrsm.connectors.sql.SQLConnector',
        'sqlalchemy.orm.session.Session',
        'sqlalchemy.engine.base.Engine'
    ]
        The connection object used to fetch the columns and types.

    flavor: Optional[str], default None
        The database dialect flavor to use for the query.
        If omitted, default to `connectable.flavor`.

    schema: Optional[str], default None
        If provided, restrict the query to this schema.

    database: Optional[str], default None
        If provided, restrict the query to this database.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A dictionary mapping column names to a list of indices.
    """
    import textwrap
    from collections import defaultdict
    from meerschaum.connectors import SQLConnector
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    flavor = flavor or getattr(connectable, 'flavor', None)
    if not flavor:
        raise ValueError("Please provide a database flavor.")
    if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
        raise ValueError("You must provide a SQLConnector when using DuckDB.")
    if flavor in NO_SCHEMA_FLAVORS:
        schema = None
    if schema is None:
        schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
    if flavor in ('sqlite', 'duckdb', 'oracle'):
        database = None
    ### Some flavors match on case-folded or truncated variants of the table name,
    ### so prepare all of them for the flavor-specific query template.
    table_trunc = truncate_item_name(table, flavor=flavor)
    table_lower = table.lower()
    table_upper = table.upper()
    table_lower_trunc = truncate_item_name(table_lower, flavor=flavor)
    table_upper_trunc = truncate_item_name(table_upper, flavor=flavor)
    ### MSSQL temporary tables (names starting with '#') live in `tempdb`.
    db_prefix = (
        "tempdb."
        if flavor == 'mssql' and table.startswith('#')
        else ""
    )

    cols_indices_query = sqlalchemy.text(
        textwrap.dedent(columns_indices_queries.get(
            flavor,
            columns_indices_queries['default']
        ).format(
            table=table,
            table_trunc=table_trunc,
            table_lower=table_lower,
            table_lower_trunc=table_lower_trunc,
            table_upper=table_upper,
            table_upper_trunc=table_upper_trunc,
            db_prefix=db_prefix,
            schema=schema,
        )).lstrip().rstrip()
    )

    ### Expected column order of the result rows; MSSQL additionally reports
    ### whether each index is clustered.
    cols = ['database', 'schema', 'table', 'column', 'index', 'index_type']
    if flavor == 'mssql':
        cols.append('clustered')
    result_cols_ix = dict(enumerate(cols))

    ### Only `SQLConnector.execute()` accepts a `debug` keyword;
    ### raw sessions/engines do not, so print the query here instead.
    debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
    if not debug_kwargs and debug:
        dprint(cols_indices_query)

    try:
        result_rows = (
            [
                row
                for row in connectable.execute(cols_indices_query, **debug_kwargs).fetchall()
            ]
            if flavor != 'duckdb'
            else [
                tuple([doc[col] for col in cols])
                for doc in connectable.read(cols_indices_query, debug=debug).to_dict(orient='records')
            ]
        )
        cols_types_docs = [
            {
                result_cols_ix[i]: val
                for i, val in enumerate(row)
            }
            for row in result_rows
        ]
        cols_types_docs_filtered = [
            doc
            for doc in cols_types_docs
            if (
                (
                    not schema
                    or doc['schema'] == schema
                )
                and
                (
                    not database
                    or doc['database'] == database
                )
            )
        ]
        ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
        if cols_types_docs and not cols_types_docs_filtered:
            cols_types_docs_filtered = cols_types_docs

        cols_indices = defaultdict(lambda: [])
        for doc in cols_types_docs_filtered:
            ### Oracle reports unquoted identifiers in upper case; fold simple
            ### all-caps alphabetic names back to lower case for consistency.
            col = (
                doc['column']
                if flavor != 'oracle'
                else (
                    doc['column'].lower()
                    if (doc['column'].isupper() and doc['column'].replace('_', '').isalpha())
                    else doc['column']
                )
            )
            index_doc = {
                'name': doc.get('index', None),
                'type': doc.get('index_type', None)
            }
            if flavor == 'mssql':
                index_doc['clustered'] = doc.get('clustered', None)
            cols_indices[col].append(index_doc)

        return dict(cols_indices)
    except Exception as e:
        warn(f"Failed to fetch columns for table '{table}':\n{e}")
        return {}

Return a dictionary mapping a table's columns to lists of indices. This is useful for inspecting tables created during a not-yet-committed session.

NOTE: This may return incorrect columns if the schema is not explicitly stated. Use this function if you are confident the table name is unique or if you have an explicit schema. To use the configured schema, get the columns from get_sqlalchemy_table() instead.

Parameters
  • table (str): The name of the table (unquoted).
  • connectable (Union['mrsm.connectors.sql.SQLConnector', 'sqlalchemy.orm.session.Session', 'sqlalchemy.engine.base.Engine']): The connection object used to fetch the columns and types.
  • flavor (Optional[str], default None): The database dialect flavor to use for the query. If omitted, default to connectable.flavor.
  • schema (Optional[str], default None): If provided, restrict the query to this schema.
  • database (Optional[str], default None): If provided, restrict the query to this database.
Returns
  • A dictionary mapping column names to a list of indices.
def get_update_queries( target: str, patch: str, connectable: "Union[mrsm.connectors.sql.SQLConnector, 'sqlalchemy.orm.session.Session']", join_cols: Iterable[str], flavor: Optional[str] = None, upsert: bool = False, datetime_col: Optional[str] = None, schema: Optional[str] = None, patch_schema: Optional[str] = None, identity_insert: bool = False, null_indices: bool = True, cast_columns: bool = True, debug: bool = False) -> List[str]:
def get_update_queries(
    target: str,
    patch: str,
    connectable: Union[
        mrsm.connectors.sql.SQLConnector,
        'sqlalchemy.orm.session.Session'
    ],
    join_cols: Iterable[str],
    flavor: Optional[str] = None,
    upsert: bool = False,
    datetime_col: Optional[str] = None,
    schema: Optional[str] = None,
    patch_schema: Optional[str] = None,
    identity_insert: bool = False,
    null_indices: bool = True,
    cast_columns: bool = True,
    debug: bool = False,
) -> List[str]:
    """
    Build a list of `MERGE`, `UPDATE`, `DELETE`/`INSERT` queries to apply a patch to target table.

    Parameters
    ----------
    target: str
        The name of the target table.

    patch: str
        The name of the patch table. This should have the same shape as the target.

    connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
        The `SQLConnector` or SQLAlchemy session which will later execute the queries.

    join_cols: List[str]
        The columns to use to join the patch to the target.

    flavor: Optional[str], default None
        If using a SQLAlchemy session, provide the expected database flavor.

    upsert: bool, default False
        If `True`, return an upsert query rather than an update.

    datetime_col: Optional[str], default None
        If provided, bound the join query using this column as the datetime index.
        This must be present on both tables.

    schema: Optional[str], default None
        If provided, use this schema when quoting the target table.
        Defaults to `connector.schema`.

    patch_schema: Optional[str], default None
        If provided, use this schema when quoting the patch table.
        Defaults to `schema`.

    identity_insert: bool, default False
        If `True`, include `SET IDENTITY_INSERT` queries before and after the update queries.
        Only applies for MSSQL upserts.

    null_indices: bool, default True
        If `False`, do not coalesce index columns before joining.

    cast_columns: bool, default True
        If `False`, do not cast update columns to the target table types.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of query strings to perform the update operation.
    """
    import textwrap
    from meerschaum.connectors import SQLConnector
    from meerschaum.utils.debug import dprint
    from meerschaum.utils.dtypes import are_dtypes_equal
    from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES, get_pd_type_from_db_type
    flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
    if not flavor:
        raise ValueError("Provide a flavor if using a SQLAlchemy session.")

    def _version_tuple(version_str) -> Tuple[int, ...]:
        """Parse the leading numeric components of a dotted version string."""
        parts: List[int] = []
        for part in str(version_str).split('.'):
            digits = ''.join(c for c in part if c.isdigit())
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    ### NOTE: Compare versions numerically, not lexicographically:
    ### as strings, '3.9.0' > '3.33.0', which would wrongly select the modern path.
    ### SQLite versions below 3.33.0 lack `UPDATE ... FROM`, so fall back to
    ### DELETE + INSERT (also when the version cannot be determined).
    if (
        flavor == 'sqlite'
        and isinstance(connectable, SQLConnector)
        and (
            not getattr(connectable, 'db_version', None)
            or _version_tuple(connectable.db_version) < (3, 33, 0)
        )
    ):
        flavor = 'sqlite_delete_insert'
    flavor_key = (f'{flavor}-upsert' if upsert else flavor)
    base_queries = UPDATE_QUERIES.get(
        flavor_key,
        UPDATE_QUERIES['default']
    )
    if not isinstance(base_queries, list):
        base_queries = [base_queries]
    schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
    patch_schema = patch_schema or schema
    target_table_columns = get_table_cols_types(
        target,
        connectable,
        flavor=flavor,
        schema=schema,
        debug=debug,
    )
    patch_table_columns = get_table_cols_types(
        patch,
        connectable,
        flavor=flavor,
        schema=patch_schema,
        debug=debug,
    )

    patch_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )
    patch_cols_prefixed_str = ', '.join(
        [
            'p.' + sql_item_name(col, flavor)
            for col in patch_table_columns
        ]
    )

    join_cols_str = ', '.join(
        [
            sql_item_name(col, flavor)
            for col in join_cols
        ]
    )

    ### Split the shared columns into join (index) columns and value columns.
    value_cols = []
    join_cols_types = []
    if debug:
        dprint("target_table_columns:")
        mrsm.pprint(target_table_columns)
    for c_name, c_type in target_table_columns.items():
        if c_name not in patch_table_columns:
            continue
        if flavor in DB_FLAVORS_CAST_DTYPES:
            c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
        (
            join_cols_types
            if c_name in join_cols
            else value_cols
        ).append((c_name, c_type))
    if debug:
        dprint(f"value_cols: {value_cols}")

    if not join_cols_types:
        return []
    if not value_cols and not upsert:
        return []

    coalesce_join_cols_str = ', '.join(
        [
            (
                (
                    'COALESCE('
                    + sql_item_name(c_name, flavor)
                    + ', '
                    + get_null_replacement(c_type, flavor)
                    + ')'
                )
                if null_indices
                else sql_item_name(c_name, flavor)
            )
            for c_name, c_type in join_cols_types
        ]
    )

    update_or_nothing = ('UPDATE' if value_cols else 'NOTHING')

    def sets_subquery(l_prefix: str, r_prefix: str):
        """Build the `SET` clause, casting values and normalizing UTC offsets."""
        if not value_cols:
            return ''

        utc_value_cols = {
            c_name
            for c_name, c_type in value_cols
            if ('utc' in get_pd_type_from_db_type(c_type).lower())
        } if flavor not in TIMEZONE_NAIVE_FLAVORS else set()

        cast_func_cols = {
            c_name: (
                ('', '', '')
                if not cast_columns or (
                    flavor == 'oracle'
                    and are_dtypes_equal(get_pd_type_from_db_type(c_type), 'bytes')
                )
                else (
                    ('CAST(', f" AS {c_type.replace('_', ' ')}", ')' + (
                        " AT TIME ZONE 'UTC'"
                        if c_name in utc_value_cols
                        else ''
                    ))
                    if flavor != 'sqlite'
                    else ('', '', '')
                )
            )
            for c_name, c_type in value_cols
        }
        return 'SET ' + ',\n'.join([
            (
                l_prefix + sql_item_name(c_name, flavor, None)
                + ' = '
                + cast_func_cols[c_name][0]
                + r_prefix + sql_item_name(c_name, flavor, None)
                + cast_func_cols[c_name][1]
                + cast_func_cols[c_name][2]
            ) for c_name, c_type in value_cols
        ])

    def and_subquery(l_prefix: str, r_prefix: str):
        """Build the join predicate, optionally coalescing NULLs on both sides."""
        return '\n            AND\n                '.join([
            (
                (
                    "COALESCE("
                    + l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                    + '\n                =\n                '
                    + "COALESCE("
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ", "
                    + get_null_replacement(c_type, flavor)
                    + ")"
                )
                if null_indices
                else (
                    l_prefix
                    + sql_item_name(c_name, flavor, None)
                    + ' = '
                    + r_prefix
                    + sql_item_name(c_name, flavor, None)
                )
            ) for c_name, c_type in join_cols_types
        ])

    skip_query_val = ""
    target_table_name = sql_item_name(target, flavor, schema)
    patch_table_name = sql_item_name(patch, flavor, patch_schema)
    dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
    ### MSSQL reads the bounds from a CTE; other flavors aggregate inline.
    date_bounds_table = patch_table_name if flavor != 'mssql' else '[date_bounds]'
    min_dt_col_name = f"MIN({dt_col_name})" if flavor != 'mssql' else '[Min_dt]'
    max_dt_col_name = f"MAX({dt_col_name})" if flavor != 'mssql' else '[Max_dt]'
    date_bounds_subquery = (
        f"""f.{dt_col_name} >= (SELECT {min_dt_col_name} FROM {date_bounds_table})
            AND
                f.{dt_col_name} <= (SELECT {max_dt_col_name} FROM {date_bounds_table})"""
        if datetime_col
        else "1 = 1"
    )
    with_temp_date_bounds = f"""WITH [date_bounds] AS (
        SELECT MIN({dt_col_name}) AS {min_dt_col_name}, MAX({dt_col_name}) AS {max_dt_col_name}
        FROM {patch_table_name}
    )""" if datetime_col else ""
    identity_insert_on = (
        f"SET IDENTITY_INSERT {target_table_name} ON"
        if identity_insert
        else skip_query_val
    )
    identity_insert_off = (
        f"SET IDENTITY_INSERT {target_table_name} OFF"
        if identity_insert
        else skip_query_val
    )

    ### NOTE: MSSQL upserts must exclude the update portion if only upserting indices.
    when_matched_update_sets_subquery_none = "" if not value_cols else (
        "\n        WHEN MATCHED THEN\n"
        f"            UPDATE {sets_subquery('', 'p.')}"
    )

    cols_equal_values = '\n,'.join(
        [
            f"{sql_item_name(c_name, flavor)} = VALUES({sql_item_name(c_name, flavor)})"
            for c_name, c_type in value_cols
        ]
    )
    on_duplicate_key_update = (
        "ON DUPLICATE KEY UPDATE"
        if value_cols
        else ""
    )
    ignore = "IGNORE " if not value_cols else ""

    formatted_queries = [
        textwrap.dedent(base_query.format(
            sets_subquery_none=sets_subquery('', 'p.'),
            sets_subquery_none_excluded=sets_subquery('', 'EXCLUDED.'),
            sets_subquery_f=sets_subquery('f.', 'p.'),
            and_subquery_f=and_subquery('p.', 'f.'),
            and_subquery_t=and_subquery('p.', 't.'),
            target_table_name=target_table_name,
            patch_table_name=patch_table_name,
            patch_cols_str=patch_cols_str,
            patch_cols_prefixed_str=patch_cols_prefixed_str,
            date_bounds_subquery=date_bounds_subquery,
            join_cols_str=join_cols_str,
            coalesce_join_cols_str=coalesce_join_cols_str,
            update_or_nothing=update_or_nothing,
            when_matched_update_sets_subquery_none=when_matched_update_sets_subquery_none,
            cols_equal_values=cols_equal_values,
            on_duplicate_key_update=on_duplicate_key_update,
            ignore=ignore,
            with_temp_date_bounds=with_temp_date_bounds,
            identity_insert_on=identity_insert_on,
            identity_insert_off=identity_insert_off,
        )).lstrip().rstrip()
        for base_query in base_queries
    ]

    ### NOTE: Allow for skipping some queries.
    return [query for query in formatted_queries if query]

Build a list of MERGE, UPDATE, DELETE/INSERT queries to apply a patch to target table.

Parameters
  • target (str): The name of the target table.
  • patch (str): The name of the patch table. This should have the same shape as the target.
  • connectable (Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]): The SQLConnector or SQLAlchemy session which will later execute the queries.
  • join_cols (List[str]): The columns to use to join the patch to the target.
  • flavor (Optional[str], default None): If using a SQLAlchemy session, provide the expected database flavor.
  • upsert (bool, default False): If True, return an upsert query rather than an update.
  • datetime_col (Optional[str], default None): If provided, bound the join query using this column as the datetime index. This must be present on both tables.
  • schema (Optional[str], default None): If provided, use this schema when quoting the target table. Defaults to connector.schema.
  • patch_schema (Optional[str], default None): If provided, use this schema when quoting the patch table. Defaults to schema.
  • identity_insert (bool, default False): If True, include SET IDENTITY_INSERT queries before and after the update queries. Only applies for MSSQL upserts.
  • null_indices (bool, default True): If False, do not coalesce index columns before joining.
  • cast_columns (bool, default True): If False, do not cast update columns to the target table types.
  • debug (bool, default False): Verbosity toggle.
Returns
  • A list of query strings to perform the update operation.
def get_null_replacement(typ: str, flavor: str) -> str:
def get_null_replacement(typ: str, flavor: str) -> str:
    """
    Return a value that may temporarily be used in place of NULL for this type.

    Parameters
    ----------
    typ: str
        The SQL type for which a NULL-replacement value is needed.

    flavor: str
        The database flavor for which this value will be used.

    Returns
    -------
    A value which may stand in place of NULL for this type.
    `'None'` is returned if a value cannot be determined.
    """
    if 'int' in typ.lower() or typ.lower() in ('numeric', 'number'):
        return '-987654321'
    if 'bool' in typ.lower() or typ.lower() == 'bit':
        ### Import lazily so that the common numeric/string paths skip the import.
        from meerschaum.utils.dtypes.sql import DB_FLAVORS_CAST_DTYPES
        bool_typ = (
            PD_TO_DB_DTYPES_FLAVORS
            .get('bool', {})
            .get(flavor, PD_TO_DB_DTYPES_FLAVORS['bool']['default'])
        )
        if flavor in DB_FLAVORS_CAST_DTYPES:
            bool_typ = DB_FLAVORS_CAST_DTYPES[flavor].get(bool_typ, bool_typ)
        ### MySQL/MariaDB reuse the integer magic value for their boolean cast.
        val_to_cast = (
            -987654321
            if flavor in ('mysql', 'mariadb')
            else 0
        )
        return f'CAST({val_to_cast} AS {bool_typ})'
    if 'time' in typ.lower() or 'date' in typ.lower():
        ### Pass the exact DB type only when it looks like a raw (upper-case) type.
        db_type = typ if typ.isupper() else None
        return dateadd_str(flavor=flavor, begin='1900-01-01', db_type=db_type)
    if 'float' in typ.lower() or 'double' in typ.lower() or typ.lower() in ('decimal',):
        return '-987654321.0'
    if flavor == 'oracle' and typ.lower().split('(', maxsplit=1)[0] == 'char':
        return "'-987654321'"
    if flavor == 'oracle' and typ.lower() in ('blob', 'bytes'):
        return '00'
    if typ.lower() in ('uniqueidentifier', 'guid', 'uuid'):
        magic_val = 'DEADBEEF-ABBA-BABE-CAFE-DECAFC0FFEE5'
        if flavor == 'mssql':
            return f"CAST('{magic_val}' AS UNIQUEIDENTIFIER)"
        return f"'{magic_val}'"
    ### Fallback string sentinel (Oracle requires the n-prefix for NVARCHAR literals).
    return ('n' if flavor == 'oracle' else '') + "'-987654321'"

Return a value that may temporarily be used in place of NULL for this type.

Parameters
  • typ (str): The SQL type for which a NULL-replacement value is needed.
  • flavor (str): The database flavor for which this value will be used.
Returns
  • A value which may stand in place of NULL for this type.
  • 'None' is returned if a value cannot be determined.
def get_db_version(conn: "'SQLConnector'", debug: bool = False) -> Optional[str]:
def get_db_version(conn: 'SQLConnector', debug: bool = False) -> Union[str, None]:
    """
    Fetch the database version if possible.

    Parameters
    ----------
    conn: SQLConnector
        The connector whose database version should be fetched.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    The version string reported by the database, or `None` if unavailable.
    """
    escaped_version_name = sql_item_name('version', conn.flavor, None)
    flavor_query_template = version_queries.get(conn.flavor, version_queries['default'])
    return conn.value(
        flavor_query_template.format(version_name=escaped_version_name),
        debug=debug,
    )

Fetch the database version if possible.

def get_rename_table_queries( old_table: str, new_table: str, flavor: str, schema: Optional[str] = None) -> List[str]:
def get_rename_table_queries(
    old_table: str,
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> List[str]:
    """
    Return queries to alter a table's name.

    Parameters
    ----------
    old_table: str
        The unquoted name of the old table.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table resides.

    Returns
    -------
    A list of `ALTER TABLE` or equivalent queries for the database flavor.
    """
    if flavor == 'mssql':
        ### `sp_rename` takes the unquoted table names directly.
        return [f"EXEC sp_rename '{old_table}', '{new_table}'"]

    old_table_name = sql_item_name(old_table, flavor, schema)
    ### NOTE: the new name is deliberately not schema-qualified
    ### (`RENAME TO` rejects qualified targets on most flavors).
    new_table_name = sql_item_name(new_table, flavor, None)

    if flavor == 'duckdb':
        ### DuckDB is handled by copying through a temporary table
        ### rather than a direct rename.
        tmp_table = '_tmp_rename_' + new_table
        tmp_table_name = sql_item_name(tmp_table, flavor, schema)
        if_exists_str = "IF EXISTS" if flavor in DROP_IF_EXISTS_FLAVORS else ""
        copy_to_tmp = get_create_table_queries(
            f"SELECT * FROM {old_table_name}",
            tmp_table,
            'duckdb',
            schema,
        )
        copy_to_new = get_create_table_queries(
            f"SELECT * FROM {tmp_table_name}",
            new_table,
            'duckdb',
            schema,
        )
        drops = [
            f"DROP TABLE {if_exists_str} {tmp_table_name}",
            f"DROP TABLE {if_exists_str} {old_table_name}",
        ]
        return copy_to_tmp + copy_to_new + drops

    return [f"ALTER TABLE {old_table_name} RENAME TO {new_table_name}"]

Return queries to alter a table's name.

Parameters
  • old_table (str): The unquoted name of the old table.
  • new_table (str): The unquoted name of the new table.
  • flavor (str): The database flavor to use for the query (e.g. 'mssql', 'postgresql').
  • schema (Optional[str], default None): The schema on which the table resides.
Returns
  • A list of ALTER TABLE or equivalent queries for the database flavor.
def get_create_table_query( query_or_dtypes: Union[str, Dict[str, str]], new_table: str, flavor: str, schema: Optional[str] = None) -> str:
def get_create_table_query(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
) -> str:
    """
    NOTE: This function is deprecated. Use `get_create_table_queries()` instead.

    Return a query to create a new table from a `SELECT` query.

    Parameters
    ----------
    query: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    ### Legacy shim: delegate and return only the first generated query.
    queries = get_create_table_queries(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=None,
    )
    return queries[0]

NOTE: This function is deprecated. Use get_create_table_queries() instead.

Return a query to create a new table from a SELECT query.

Parameters
  • query (Union[str, Dict[str, str]]): The select query to use for the creation of the table. If a dictionary is provided, return a CREATE TABLE query from the given dtypes columns.
  • new_table (str): The unquoted name of the new table.
  • flavor (str): The database flavor to use for the query (e.g. 'mssql', 'postgresql').
  • schema (Optional[str], default None): The schema on which the table will reside.
Returns
  • A CREATE TABLE (or SELECT INTO) query for the database flavor.
def get_create_table_queries( query_or_dtypes: Union[str, Dict[str, str]], new_table: str, flavor: str, schema: Optional[str] = None, primary_key: Optional[str] = None, primary_key_db_type: Optional[str] = None, autoincrement: bool = False, datetime_column: Optional[str] = None) -> List[str]:
def get_create_table_queries(
    query_or_dtypes: Union[str, Dict[str, str]],
    new_table: str,
    flavor: str,
    schema: Optional[str] = None,
    primary_key: Optional[str] = None,
    primary_key_db_type: Optional[str] = None,
    autoincrement: bool = False,
    datetime_column: Optional[str] = None,
) -> List[str]:
    """
    Return a query to create a new table from a `SELECT` query or a `dtypes` dictionary.

    Parameters
    ----------
    query_or_dtypes: Union[str, Dict[str, str]]
        The select query to use for the creation of the table.
        If a dictionary is provided, return a `CREATE TABLE` query from the given `dtypes` columns.

    new_table: str
        The unquoted name of the new table.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    schema: Optional[str], default None
        The schema on which the table will reside.

    primary_key: Optional[str], default None
        If provided, designate this column as the primary key in the new table.

    primary_key_db_type: Optional[str], default None
        If provided, alter the primary key to this type (to set NOT NULL constraint).

    autoincrement: bool, default False
        If `True` and `primary_key` is provided, create the `primary_key` column
        as an auto-incrementing integer column.

    datetime_column: Optional[str], default None
        If provided, include this column in the primary key.
        Applicable to TimescaleDB only.

    Returns
    -------
    A `CREATE TABLE` (or `SELECT INTO`) query for the database flavor.
    """
    ### Dispatch on the input kind: a query string builds via CTE,
    ### a dtypes dictionary builds an explicit column list.
    if isinstance(query_or_dtypes, str):
        method = _get_create_table_query_from_cte
    elif isinstance(query_or_dtypes, dict):
        method = _get_create_table_query_from_dtypes
    else:
        raise TypeError("`query_or_dtypes` must be a query or a dtypes dictionary.")

    ### Some flavors (e.g. Citus, DuckDB) do not support auto-incrementing columns.
    allow_autoincrement = autoincrement and flavor not in SKIP_AUTO_INCREMENT_FLAVORS
    return method(
        query_or_dtypes,
        new_table,
        flavor,
        schema=schema,
        primary_key=primary_key,
        primary_key_db_type=primary_key_db_type,
        autoincrement=allow_autoincrement,
        datetime_column=datetime_column,
    )

Return a query to create a new table from a SELECT query or a dtypes dictionary.

Parameters
  • query_or_dtypes (Union[str, Dict[str, str]]): The select query to use for the creation of the table. If a dictionary is provided, return a CREATE TABLE query from the given dtypes columns.
  • new_table (str): The unquoted name of the new table.
  • flavor (str): The database flavor to use for the query (e.g. 'mssql', 'postgresql').
  • schema (Optional[str], default None): The schema on which the table will reside.
  • primary_key (Optional[str], default None): If provided, designate this column as the primary key in the new table.
  • primary_key_db_type (Optional[str], default None): If provided, alter the primary key to this type (to set NOT NULL constraint).
  • autoincrement (bool, default False): If True and primary_key is provided, create the primary_key column as an auto-incrementing integer column.
  • datetime_column (Optional[str], default None): If provided, include this column in the primary key. Applicable to TimescaleDB only.
Returns
  • A CREATE TABLE (or SELECT INTO) query for the database flavor.
def wrap_query_with_cte( sub_query: str, parent_query: str, flavor: str, cte_name: str = 'src') -> str:
def wrap_query_with_cte(
    sub_query: str,
    parent_query: str,
    flavor: str,
    cte_name: str = "src",
) -> str:
    """
    Wrap a subquery in a CTE and append an encapsulating query.

    Parameters
    ----------
    sub_query: str
        The query to be referenced. This may itself contain CTEs.
        Unless `cte_name` is provided, this will be aliased as `src`.

    parent_query: str
        The larger query to append which references the subquery.
        This must not contain CTEs.

    flavor: str
        The database flavor, e.g. `'mssql'`.

    cte_name: str, default 'src'
        The CTE alias, defaults to `src`.

    Returns
    -------
    An encapsulating query which allows you to treat `sub_query` as a temporary table.

    Examples
    --------

    ```python
    from meerschaum.utils.sql import wrap_query_with_cte
    sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
    parent_query = "SELECT newval * 3 FROM src"
    query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
    print(query)
    # WITH foo AS (SELECT 1 AS val),
    # [src] AS (
    #     SELECT (val * 2) AS newval FROM foo
    # )
    # SELECT newval * 3 FROM src
    ```

    """
    import textwrap
    sub_query = sub_query.lstrip()
    cte_name_quoted = sql_item_name(cte_name, flavor, None)

    if flavor in NO_CTE_FLAVORS:
        ### No CTE support: inline the subquery as a derived table wherever
        ### the parent references the CTE alias (quoted or unquoted).
        placeholder = '--MRSM_SUBQUERY--'
        inlined = parent_query.replace(cte_name_quoted, placeholder)
        inlined = inlined.replace(cte_name, placeholder)
        return inlined.replace(placeholder, f"(\n{sub_query}\n) AS {cte_name_quoted}")

    if sub_query.lower().startswith('with '):
        ### The subquery already opens a WITH clause; splice our CTE
        ### onto the existing list around the final SELECT.
        final_select_ix = sub_query.lower().rfind('select')
        pieces = [
            sub_query[:final_select_ix].rstrip(),
            ',\n',
            f"{cte_name_quoted} AS (\n    ",
            sub_query[final_select_ix:],
            "\n)\n",
            parent_query,
        ]
        return ''.join(pieces)

    indented_sub_query = textwrap.indent(sub_query, '    ')
    return f"WITH {cte_name_quoted} AS (\n{indented_sub_query}\n)\n{parent_query}"

Wrap a subquery in a CTE and append an encapsulating query.

Parameters
  • sub_query (str): The query to be referenced. This may itself contain CTEs. Unless cte_name is provided, this will be aliased as src.
  • parent_query (str): The larger query to append which references the subquery. This must not contain CTEs.
  • flavor (str): The database flavor, e.g. 'mssql'.
  • cte_name (str, default 'src'): The CTE alias, defaults to src.
Returns
  • An encapsulating query which allows you to treat sub_query as a temporary table.
Examples
from meerschaum.utils.sql import wrap_query_with_cte
sub_query = "WITH foo AS (SELECT 1 AS val) SELECT (val * 2) AS newval FROM foo"
parent_query = "SELECT newval * 3 FROM src"
query = wrap_query_with_cte(sub_query, parent_query, 'mssql')
print(query)
# WITH foo AS (SELECT 1 AS val),
# [src] AS (
#     SELECT (val * 2) AS newval FROM foo
# )
# SELECT newval * 3 FROM src
def format_cte_subquery( sub_query: str, flavor: str, sub_name: str = 'src', cols_to_select: Union[List[str], str] = '*') -> str:
def format_cte_subquery(
    sub_query: str,
    flavor: str,
    sub_name: str = 'src',
    cols_to_select: Union[List[str], str] = '*',
) -> str:
    """
    Given a subquery, build a wrapper query that selects from the CTE subquery.

    Parameters
    ----------
    sub_query: str
        The subquery to wrap.

    flavor: str
        The database flavor to use for the query (e.g. `'mssql'`, `'postgresql'`).

    sub_name: str, default 'src'
        If possible, give this name to the CTE (must be unquoted).

    cols_to_select: Union[List[str], str], default '*'
        If specified, choose which columns to select from the CTE.
        If a list of strings is provided, each item will be quoted and joined with commas.
        If a string is given, assume it is quoted and insert it into the query.

    Returns
    -------
    A wrapper query that selects from the CTE.
    """
    if isinstance(cols_to_select, str):
        ### A string is assumed to already be quoted / formatted.
        cols_str = cols_to_select
    else:
        cols_str = ', '.join(sql_item_name(col, flavor, None) for col in cols_to_select)

    quoted_sub_name = sql_item_name(sub_name, flavor, None)
    parent_query = f"SELECT {cols_str}\nFROM {quoted_sub_name}"
    return wrap_query_with_cte(sub_query, parent_query, flavor, cte_name=sub_name)

Given a subquery, build a wrapper query that selects from the CTE subquery.

Parameters
  • sub_query (str): The subquery to wrap.
  • flavor (str): The database flavor to use for the query (e.g. 'mssql', 'postgresql').
  • sub_name (str, default 'src'): If possible, give this name to the CTE (must be unquoted).
  • cols_to_select (Union[List[str], str], default '*'): If specified, choose which columns to select from the CTE. If a list of strings is provided, each item will be quoted and joined with commas. If a string is given, assume it is quoted and insert it into the query.
Returns
  • A wrapper query that selects from the CTE.
def session_execute( session: "'sqlalchemy.orm.session.Session'", queries: Union[List[str], str], with_results: bool = False, debug: bool = False) -> "Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]":
def session_execute(
    session: 'sqlalchemy.orm.session.Session',
    queries: Union[List[str], str],
    with_results: bool = False,
    debug: bool = False,
) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
    """
    Similar to `SQLConnector.exec_queries()`, execute a list of queries
    and roll back when one fails.

    Parameters
    ----------
    session: sqlalchemy.orm.session.Session
        A SQLAlchemy session representing a transaction.

    queries: Union[List[str], str]
        A query or list of queries to be executed.
        If a query fails, roll back the session.

    with_results: bool, default False
        If `True`, return a list of result objects.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A `SuccessTuple` indicating the queries were successfully executed.
    If `with_results`, return the `SuccessTuple` and a list of results.
    """
    sqlalchemy = mrsm.attempt_import('sqlalchemy', lazy=False)
    queries_list = queries if isinstance(queries, list) else [queries]
    fail_msg = "Failed to execute queries."
    successes, msgs, results = [], [], []
    for query in queries_list:
        if debug:
            dprint(query)
        try:
            result = session.execute(sqlalchemy.text(query))
            ### A `None` result is treated as a failure.
            query_success = result is not None
            query_msg = "Success" if query_success else fail_msg
        except Exception as e:
            query_success = False
            query_msg = f"{fail_msg}\n{e}"
            result = None
        successes.append(query_success)
        msgs.append(query_msg)
        results.append(result)
        if not query_success:
            ### Stop at the first failure and undo the transaction.
            if debug:
                dprint("Rolling back session.")
            session.rollback()
            break

    success = all(successes)
    msg = '\n'.join(msgs)
    if with_results:
        return (success, msg), results
    return success, msg

Similar to SQLConnector.exec_queries(), execute a list of queries and roll back when one fails.

Parameters
  • session (sqlalchemy.orm.session.Session): A SQLAlchemy session representing a transaction.
  • queries (Union[List[str], str]): A query or list of queries to be executed. If a query fails, roll back the session.
  • with_results (bool, default False): If True, return a list of result objects.
Returns
  • A SuccessTuple indicating the queries were successfully executed.
  • If with_results, return the SuccessTuple and a list of results.
def get_reset_autoincrement_queries( table: str, column: str, connector: mrsm.connectors.SQLConnector, schema: Optional[str] = None, debug: bool = False) -> List[str]:
def get_reset_autoincrement_queries(
    table: str,
    column: str,
    connector: mrsm.connectors.SQLConnector,
    schema: Optional[str] = None,
    debug: bool = False,
) -> List[str]:
    """
    Return a list of queries to reset a table's auto-increment counter to the next largest value.

    Parameters
    ----------
    table: str
        The name of the table on which the auto-incrementing column exists.

    column: str
        The name of the auto-incrementing column.

    connector: mrsm.connectors.SQLConnector
        The SQLConnector to the database on which the table exists.

    schema: Optional[str], default None
        The schema of the table. Defaults to `connector.schema`.

    debug: bool, default False
        Verbosity toggle.

    Returns
    -------
    A list of queries to be executed to reset the auto-incrementing column.
    """
    if not table_exists(table, connector, schema=schema, debug=debug):
        return []

    schema = schema or connector.schema
    flavor = connector.flavor
    max_id_name = sql_item_name('max_id', flavor)
    table_name = sql_item_name(table, flavor, schema)
    ### Sequence objects follow the `<table>_<column>_seq` naming convention.
    table_seq_name = sql_item_name(f"{table}_{column}_seq", flavor, schema)
    column_name = sql_item_name(column, flavor)

    max_id = connector.value(
        f"""
        SELECT COALESCE(MAX({column_name}), 0) AS {max_id_name}
        FROM {table_name}
        """,
        debug=debug,
    )
    if max_id is None:
        return []

    flavor_queries = reset_autoincrement_queries.get(
        flavor,
        reset_autoincrement_queries['default']
    )
    queries_to_format = (
        flavor_queries
        if isinstance(flavor_queries, list)
        else [flavor_queries]
    )

    format_kwargs = {
        'column': column,
        'column_name': column_name,
        'table': table,
        'table_name': table_name,
        'table_seq_name': table_seq_name,
        'val': max_id,
        'val_plus_1': max_id + 1,
    }
    return [query.format(**format_kwargs) for query in queries_to_format]

Return a list of queries to reset a table's auto-increment counter to the next largest value.

Parameters
  • table (str): The name of the table on which the auto-incrementing column exists.
  • column (str): The name of the auto-incrementing column.
  • connector (mrsm.connectors.SQLConnector): The SQLConnector to the database on which the table exists.
  • schema (Optional[str], default None): The schema of the table. Defaults to connector.schema.
Returns
  • A list of queries to be executed to reset the auto-incrementing column.