As with GROUP BY and ORDER BY, columns used in joins and filters can be good candidates for indexes, as they can speed up the performance of your queries.
Filters
import pandas as pd
from flowhigh.model import *
from flowhigh.model.OpType import OpType
from flowhigh.utils import FlowHighSubmissionClass
# Logical connector node types (AND/OR) whose nested sub-expressions must be
# unpacked when walking filter predicates
LOGICAL = [OpType.OR.name, OpType.AND.name]
# Ds type classification: maps a Ds node's type_/subType code to a
# human-readable table kind (call sites default unknown codes to 'Physical Table')
TABLE_TYPES = {
    'physical': 'Physical Table',
    'table': 'Physical Table',
    'cte': 'Common Table Expression',
    'inline': 'Inline View',
    'recursive': 'Recursive CTE',  # fixed typo: was 'Recurstive CTE'
    'pseudo': 'Pseudo Table',
    'udtf': 'Table Function'
}
def get_col_name(col: Attr):
    """
    Build the display identifier for a column.

    :param col: the Attr node
    :return: the upper-cased attribute identifier, with its alias appended when present
    """
    if col.oref:
        name = fh.get_object_from_dbohier(col.oref).name
    else:
        name = col.refatt
    # 'Col$...' aliases are generated placeholders, not user-written aliases
    if col.alias and 'Col$' not in col.alias:
        name = f'{name} AS {col.alias}'
    return name.upper()
def get_ds_name(ds: Ds):
    """
    Build the display identifier for a dataset.

    :param ds: the Ds node
    :return: the upper-cased ds identifier, with its alias appended when present
    """
    if not ds.oref:
        # no DBO reference: fall back to the alias or the node's raw SQL text
        return (ds.alias or fh.get_node_raw_text(ds) or '').upper()
    name = fh.get_DBO_fullref(fh.get_object_from_dbohier(ds.oref))
    if ds.alias:
        # the ds has an explicit alias
        name = f'{name} AS {ds.alias}'
    return name.upper()
def collect_ds(source_obj: TreeNode, accum: list):
    """
    Recursively collect every Ds node in the hierarchy.

    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Ds objects
    :return:
    """
    if not source_obj:
        return
    # keep only Dataset nodes
    if isinstance(source_obj, Ds):
        accum.append(source_obj)
    for node in source_obj.get_children():
        collect_ds(node, accum)
def collect_attr(source_obj: TreeNode, accum: list):
    """
    Recursively collect every Attr node in the hierarchy.

    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Attr objects
    :return:
    """
    if not source_obj:
        return
    # keep Attr nodes that are not variable references; stop descending there
    if isinstance(source_obj, Attr) and not source_obj.refvar:
        accum.append(source_obj)
        return
    # for datasets, descend only into the output attributes and expressions
    if isinstance(source_obj, Ds):
        return collect_attr(source_obj.out, accum)
    for node in source_obj.get_children():
        collect_attr(node, accum)
def is_parent_ds(ds, col):
    """
    Tell whether `col` is an attribute of dataset `ds`.

    :param ds: current Ds
    :param col: attribute or expression to link
    :return: True if `col` belongs to `ds`
    """
    # fully qualified name of the dataset
    if ds.oref:
        ds_fullref = fh.get_DBO_fullref(fh.get_object_from_dbohier(ds.oref))
    else:
        ds_fullref = ds.fullref
    # fully qualified name of the table the column references
    if col.oref:
        parent_tab = fh.get_object_from_dbohier(col.oref).get_parent()
        attr_dsref = fh.get_DBO_fullref(parent_tab)
    else:
        parts = (col.refdb, col.refsch, col.refds)
        attr_dsref = '.'.join(p for p in parts if p)
    # the Attr's qualifier must match the ds alias or its fully qualified name
    names = {ref.casefold() for ref in (ds.alias, ds_fullref) if ref}
    return attr_dsref and attr_dsref.casefold() in names
def get_parent_ds(ds_list, attr):
    """
    Find the dataset that `attr` belongs to.

    :param ds_list: list of Ds to match
    :param attr: attribute or expression to link
    :return: the first matching Ds (with sref indirections resolved), or None
    """
    for candidate in ds_list:
        if not is_parent_ds(candidate, attr):
            continue
        # follow the source reference when the ds is itself a pointer
        return fh.search_node_by_pos(candidate.sref) if candidate.sref else candidate
    return None
def collect_regular_and_vertical_filter(f: Filter):
    """
    Collect the Attributes used in WHERE and QUALIFY clauses.

    :param f: Filter node to be visited
    :return:
    """
    # FIFO worklist: AND/OR connectors get unpacked back onto the list
    worklist = [f.op]
    i = 0
    while i < len(worklist):
        op = worklist[i]
        i += 1
        if getattr(op, 'type_', None) in LOGICAL:
            worklist.extend(op.exprs)
            continue
        lft_attrs = traverse_and_collect(collect_attr, op.exprs[0])
        rt_attrs = traverse_and_collect(collect_attr, op.exprs[-1])
        for attr in lft_attrs + rt_attrs:
            # ie. exclude non-ansi joins (attrs on both sides of the comparison)
            if len(op.exprs) == 1 or not (lft_attrs and rt_attrs):
                # {Attr -> Parent Ds}
                accum[attr] = get_parent_ds(ds_list, attr)
def collect_aggreg_filter(f: Filter):
    """
    Collect the Attributes used in HAVING clauses.

    :param f: Filter node to be visited
    :return:
    """
    # FIFO worklist: aliased attributes are resolved and revisited
    worklist = [f.op]
    i = 0
    while i < len(worklist):
        op = worklist[i]
        i += 1
        for attr in traverse_and_collect(collect_attr, op):
            # e.g. "... count(department_id) x ... having x < 10":
            # resolve the alias back to its defining expression and revisit it
            if attr.sref:
                worklist.append(fh.search_origin_reference(attr.sref))
                continue
            # {Attr -> Parent Ds}
            accum[attr] = get_parent_ds(ds_list, attr)
def collect_filter_attributes(source_obj: TreeNode):
    """
    Collect the list of all the Attributes used in WHERE/QUALIFY/HAVING
    filters and in single-sided ON-clause conditions.

    :param source_obj: TreeNode to be visited
    :return:
    """
    if not source_obj:
        return
    # check filters used in ON clauses (e.g. from a join b on a.id = 1)
    if isinstance(source_obj, Join) and source_obj.op:
        exprs = [source_obj.op]  # NOQA: ON clause
        for op in exprs:
            # unpack nested AND/OR expressions
            if getattr(op, 'type_', None) in LOGICAL:
                exprs.extend(op.exprs)
                continue
            lft_attrs = traverse_and_collect(collect_attr, op.exprs[0])
            rt_attrs = traverse_and_collect(collect_attr, op.exprs[-1])
            # single-sided condition only (from a join b on a.id = 1);
            # two-sided conditions are genuine join predicates, not filters
            if len(op.exprs) == 1 or not (lft_attrs and rt_attrs):
                for attr in lft_attrs + rt_attrs:
                    # {Attr -> Parent Ds}
                    accum[attr] = get_parent_ds(ds_list, attr)
    # get only Dataset nodes with a filter
    if isinstance(source_obj, Ds) and source_obj.modifiers:
        # explicit loops (not side-effect list comprehensions);
        # regular/vertical filters are processed before aggregate filters
        for filter_ in source_obj.modifiers:
            if isinstance(filter_, Filter) and filter_.type_ in ('filtreg', 'filtvert'):
                collect_regular_and_vertical_filter(filter_)
        for filter_ in source_obj.modifiers:
            if isinstance(filter_, Filter) and filter_.type_ == 'filtagg':
                collect_aggreg_filter(filter_)
    for child in source_obj.get_children():
        collect_filter_attributes(child)
def traverse_and_collect(func: callable, *args):
    """
    Run a recursive collector and return everything it accumulated.

    :param func: the collector; receives *args plus a fresh list to fill
    :param args: optional arguments forwarded to the collector
    :return: the list populated by `func`
    """
    collected = []
    func(*args, collected)
    return collected
def aggregate(agg: dict):
    """
    Aggregate the collected info in a pandas dataframe.

    :param agg: mapping of Attr -> parent Ds (the Ds may be None when
                no parent dataset could be resolved)
    :return: a new Pandas DataFrame with one row per collected attribute
    """
    rows = []
    for attr, ds in agg.items():
        rows.append({
            "column name": get_col_name(attr),
            # ds type (e.g. cte/inline/physical/...); guard against an
            # unresolved parent — the original crashed on ds=None here
            "table type": TABLE_TYPES.get(ds.type_ or ds.subType, 'Physical Table') if ds else None,
            "table name": get_ds_name(ds) if ds else None
        })
    return pd.DataFrame(data=rows)
# Your SQL query from editor (placeholder substituted by the host environment)
sql = """%sqlEditor%"""
# Initializing the SDK: parse the SQL and build the node hierarchy
fh = FlowHighSubmissionClass.from_sql(sql)
# global mapping {Attr -> parent Ds}, filled by the collector functions
accum = {}
# loop over the list of input Statements
for statement in fh.get_statements():
    # ds_list is a global read by get_parent_ds during collection
    ds_list = traverse_and_collect(collect_ds, statement)
    collect_filter_attributes(statement)
# render the aggregated result as a dataframe
print(aggregate(accum))
Joins
import pandas as pd
from flowhigh.model import *
from flowhigh.model.OpType import OpType
from flowhigh.utils import FlowHighSubmissionClass
# Logical connector node types (AND/OR) whose nested sub-expressions must be
# unpacked when walking join/filter predicates
LOGICAL = [OpType.OR.name, OpType.AND.name]
# Ds type classification: maps a Ds node's type_/subType code to a
# human-readable table kind (call sites default unknown codes to 'Physical Table')
TABLE_TYPES = {
    'physical': 'Physical Table',
    'table': 'Physical Table',
    'cte': 'Common Table Expression',
    'inline': 'Inline View',
    'recursive': 'Recursive CTE',  # fixed typo: was 'Recurstive CTE'
    'pseudo': 'Pseudo Table',
    'udtf': 'Table Function'
}
def get_col_name(col: Attr):
    """
    Return the display identifier of a column.

    :param col: the Attr node
    :return: the attribute's identifier, upper-cased, with its alias if any
    """
    ident = fh.get_object_from_dbohier(col.oref).name if col.oref else col.refatt
    # skip generated 'Col$...' placeholders; only keep user-written aliases
    has_user_alias = bool(col.alias) and 'Col$' not in col.alias
    if has_user_alias:
        ident = f'{ident} AS {col.alias}'
    return ident.upper()
def get_ds_name(ds: Ds):
    """
    Return the display identifier of a dataset.

    :param ds: the Ds node
    :return: the upper-cased ds identifier, with its alias if any
    """
    if not ds.oref:
        # unresolved ds: use its alias, else the raw SQL text, else empty
        fallback = ds.alias or fh.get_node_raw_text(ds) or ''
        return fallback.upper()
    ident = fh.get_DBO_fullref(fh.get_object_from_dbohier(ds.oref))
    if ds.alias:
        # the ds carries an explicit alias
        ident = f'{ident} AS {ds.alias}'
    return ident.upper()
def collect_ds(source_obj: TreeNode, accum: list):
    """
    Recursively collect every Ds node under `source_obj`.

    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Ds objects
    :return:
    """
    if not source_obj:
        return
    # only Dataset nodes are collected
    if isinstance(source_obj, Ds):
        accum.append(source_obj)
    for node in source_obj.get_children():
        collect_ds(node, accum)
def collect_attr(source_obj: TreeNode, accum: list):
    """
    Recursively collect every Attr node under `source_obj`.

    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Attr objects
    :return:
    """
    if not source_obj:
        return
    # collect Attr nodes that are not variable references; do not descend further
    if isinstance(source_obj, Attr) and not source_obj.refvar:
        accum.append(source_obj)
        return
    # for datasets, visit only the output attributes and expressions
    if isinstance(source_obj, Ds):
        return collect_attr(source_obj.out, accum)
    for node in source_obj.get_children():
        collect_attr(node, accum)
def is_parent_ds(ds, col):
    """
    Tell whether `col` is an attribute of dataset `ds`.

    :param ds: current Ds
    :param col: attribute or expression to link
    :return: True if `col` belongs to `ds`
    """
    # resolve the dataset's fully qualified name
    if ds.oref:
        ds_fullref = fh.get_DBO_fullref(fh.get_object_from_dbohier(ds.oref))
    else:
        ds_fullref = ds.fullref
    # resolve the fully qualified name of the table the column references
    if col.oref:
        owner_table = fh.get_object_from_dbohier(col.oref).get_parent()
        attr_dsref = fh.get_DBO_fullref(owner_table)
    else:
        attr_dsref = '.'.join(p for p in (col.refdb, col.refsch, col.refds) if p)
    # the Attr's qualifier must match the ds alias or its fully qualified name
    known = {ref.casefold() for ref in (ds.alias, ds_fullref) if ref}
    return attr_dsref and attr_dsref.casefold() in known
def get_parent_ds(ds_list, attrs):
    """
    Link each attribute to the datasets it belongs to.

    :param ds_list: list of Ds to match
    :param attrs: attributes or expressions to link
    :return: all matching Ds nodes, with sref indirections resolved
    """
    matches = []
    for attr in attrs:
        for ds in ds_list:
            if is_parent_ds(ds, attr):
                # follow the source reference when the ds is itself a pointer
                matches.append(fh.search_node_by_pos(ds.sref) if ds.sref else ds)
    return matches
def collect_cross_join(source_obj: Join):
    """
    Record a Cartesian product (CROSS JOIN) as a result row.

    :param source_obj: Join to be visited
    :return:
    """
    parent_in: In = source_obj.get_parent()
    # the dataset immediately preceding this join in the FROM list
    lft_ds = parent_in.exprs[parent_in.exprs.index(source_obj) - 1]
    rt_ds = source_obj.ds
    accum.append({
        "lft column name": '*',
        "lft table type": [TABLE_TYPES.get(lft_ds.type_ or lft_ds.subType, 'Physical Table')],
        "lft table name": [get_ds_name(lft_ds)],
        "join type": "cross",
        "join subtype": "",
        "rt column name": '*',
        "rt table type": [TABLE_TYPES.get(rt_ds.type_ or rt_ds.subType, 'Physical Table')],
        "rt table name": [get_ds_name(rt_ds)]
    })
def _join_row(lft_attrs, lft_ds, join_type, join_subtype, rt_attrs, rt_ds):
    """Build one result row describing a join between two attr/ds pairs."""
    return {
        "lft column name": {get_col_name(a) for a in lft_attrs},
        "lft table type": {TABLE_TYPES.get(d.type_ or d.subType, 'Physical Table') for d in lft_ds},
        "lft table name": {get_ds_name(d) for d in lft_ds},
        "join type": join_type,
        "join subtype": join_subtype,
        "rt column name": {get_col_name(a) for a in rt_attrs},
        "rt table type": {TABLE_TYPES.get(d.type_ or d.subType, 'Physical Table') for d in rt_ds},
        "rt table name": {get_ds_name(d) for d in rt_ds}
    }


def collect_join_columns(source_obj: TreeNode):
    """
    Collect the Attributes used in ANSI joins, cross joins and
    implicit (non-ANSI) joins expressed in WHERE clauses.

    :param source_obj: TreeNode to be visited
    :return:
    """
    if not source_obj:
        return
    if isinstance(source_obj, Join) and source_obj.type_ == 'cross':
        # collect Cartesian Products
        return collect_cross_join(source_obj)
    if isinstance(source_obj, Join) and source_obj.op:
        # inner/outer joins: walk the ON clause, unpacking AND/OR connectors
        exprs = [source_obj.op]  # ON clause
        for op in exprs:
            if getattr(op, 'type_', None) in LOGICAL:
                exprs.extend(op.exprs)
                continue
            # attributes and datasets on each side of the comparison
            lft_attrs = traverse_and_collect(collect_attr, op.exprs[0])
            rt_attrs = traverse_and_collect(collect_attr, op.exprs[-1])
            lft_ds = get_parent_ds(ds_list, lft_attrs)
            rt_ds = get_parent_ds(ds_list, rt_attrs)
            # keep only genuine two-sided joins between distinct datasets
            if lft_ds and rt_ds and lft_ds != rt_ds:
                accum.append(_join_row(lft_attrs, lft_ds, source_obj.type_,
                                       source_obj.subType or "", rt_attrs, rt_ds))
    # detect implicit/Non-ANSI Joins (eg. WHERE a.id = b.id)
    if isinstance(source_obj, Ds) and source_obj.modifiers:
        # walk the regular WHERE filters, unpacking AND/OR connectors
        exprs = [filter_.op for filter_ in source_obj.modifiers
                 if isinstance(filter_, Filter) and filter_.type_ == 'filtreg']
        for op in exprs:
            if getattr(op, 'type_', None) in LOGICAL:
                exprs.extend(op.exprs)
                continue
            lft_attrs = traverse_and_collect(collect_attr, op.exprs[0])
            rt_attrs = traverse_and_collect(collect_attr, op.exprs[-1])
            lft_ds = get_parent_ds(ds_list, lft_attrs)
            rt_ds = get_parent_ds(ds_list, rt_attrs)
            # both sides must carry attributes of different datasets
            if lft_attrs and rt_attrs and lft_ds != rt_ds:
                accum.append(_join_row(lft_attrs, lft_ds, 'Non Ansi', "",
                                       rt_attrs, rt_ds))
    for child in source_obj.get_children():
        collect_join_columns(child)
def traverse_and_collect(func: callable, *args):
    """
    Run a recursive collector and return everything it accumulated.

    :param func: the collector; receives *args plus a fresh list to fill
    :param args: optional arguments forwarded to the collector
    :return: the list populated by `func`
    """
    collected = []
    func(*args, collected)
    return collected
# The SQL query to be parsed (placeholder substituted by the host environment)
sql = """%sql%"""
# Initializing the SDK: parse the SQL and build the node hierarchy
fh = FlowHighSubmissionClass.from_sql(sql)
# collection where to store the list of attributes used (one dict per join)
accum = []
# loop over the list of input Statements
for statement in fh.get_statements():
    # ds_list is a global read by get_parent_ds during collection
    ds_list = traverse_and_collect(collect_ds, statement)
    collect_join_columns(statement)
# render the collected join rows as a dataframe
df = pd.DataFrame(data=accum)
print(df)