Find the list of attributes used in GROUP BY/ORDER BY

This recipe lists the columns used in the GROUP BY and ORDER BY clauses of a query, together with the table each column belongs to.
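
For illustration, the %sql% placeholder at the bottom of the script could be replaced with a query such as the one below (the SALES table and its columns are hypothetical):

SELECT region, product, SUM(amount) AS total_amount
FROM sales
GROUP BY region, product
ORDER BY total_amount DESC, 1

The script reports each GROUP BY and ORDER BY column together with the table it resolves to; positional references such as ORDER BY 1 are resolved to the output column they point at.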

import pandas as pd

from flowhigh.model import *
from flowhigh.utils import FlowHighSubmissionClass


# Ds type classification
TABLE_TYPES = {
    'physical': 'Physical Table',
    'table': 'Physical Table',
    'cte': 'Common Table Expression',
    'inline': 'Inline View',
    'recursive': 'Recursive CTE',
    'pseudo': 'Pseudo Table',
    'udtf': 'Table Function'
}


def get_col_name(col: Attr):
    """
    Return the column identifier
    :param col: the Attr node
    :return: the attribute's identifier
    """
    al = fh.get_object_from_dbohier(col.oref).name if col.oref else col.refatt
    if col.alias and 'Col$' not in col.alias:
        # check if col has an alias
        al = f'{al} AS {col.alias}'
    return al.upper()


def get_ds_name(ds: Ds):
    """
    Return the ds identifier
    :param ds: the Ds node
    :return: The ds identifier
    """
    if not ds.oref:
        return (ds.alias or fh.get_node_raw_text(ds) or '').upper()
    al = fh.get_DBO_fullref(fh.get_object_from_dbohier(ds.oref))
    if ds.alias:
        # check if ds has an alias
        al = f'{al} AS {ds.alias}'
    return al.upper()


def collect_ds(source_obj: TreeNode, accum: list):
    """
    Collect any Ds type object
    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Ds objects
    :return:
    """
    if not source_obj:
        return
    # filter only Dataset nodes
    if isinstance(source_obj, Ds):
        accum.append(source_obj)
    for child in source_obj.get_children():
        collect_ds(child, accum)


def collect_attr(source_obj: TreeNode, accum: list):
    """
    Collect any Attr type object
    :param source_obj: current node in hierarchy
    :param accum: the list where to collect the Attr objects
    :return:
    """
    if not source_obj:
        return
    # filter only Attr nodes
    if isinstance(source_obj, Attr) and not source_obj.refvar:
        accum.append(source_obj)
        return
    if isinstance(source_obj, Ds):
        # collect only out attributes and expressions
        return collect_attr(source_obj.out, accum)
    for child in source_obj.get_children():
        collect_attr(child, accum)


def is_parent_ds(ds, col):
    """
    Check if column is a ds' attribute
    :param ds: current Ds
    :param col: attribute or expression to link
    :return: True if `col` belongs to `ds`
    """
    if ds.oref:
        # calculate the DBO's fully qualified name
        ds_dbo = fh.get_object_from_dbohier(ds.oref)
        ds_fullref = fh.get_DBO_fullref(ds_dbo)
    else:
        ds_fullref = ds.fullref

    if col.oref:
        # calculate the DBO's fully qualified name
        attr_dbo = fh.get_object_from_dbohier(col.oref)
        tab = attr_dbo.get_parent()
        attr_dsref = fh.get_DBO_fullref(tab)
    else:
        attr_dsref = '.'.join([x for x in [col.refdb, col.refsch, col.refds] if x])
    # check if the Attr's fullref matches the ds alias or its fully qualified name
    return attr_dsref and attr_dsref.casefold() in [ref.casefold() for ref in (ds.alias, ds_fullref) if ref]


def get_parent_ds(ds_list, attr):
    """
    Link the column to its parent ds
    :param ds_list: list of Ds to match
    :param attr: attribute or expression to link
    :return: the first matching Ds (resolved through its sref when present), or None
    """
    return next((fh.search_node_by_pos(c.sref) if c.sref else c for c in ds_list if is_parent_ds(c, attr)), None)


def collect_group_and_sort_attributes(source_obj: TreeNode):
    """
    Populate accum_sort and accum_group with the Attributes used in the ORDER BY and GROUP BY clauses
    :param source_obj: TreeNode to be visited
    :return:
    """
    if not source_obj:
        return
    # filter only Sort nodes (ORDER BY)
    if isinstance(source_obj, Sort):
        out_ds: Out = fh.find_ancestor_of_type(source_obj, (Ds,)).out
        exprs = source_obj.exprs
        for e in exprs:
            for attr in traverse_and_collect(collect_attr, e):
                if attr.refoutidx:  # e.g. ORDER BY 1: resolve the positional reference to the matching out attribute
                    # appending extends the loop, so the referenced expression is visited as well
                    exprs.append(out_ds.exprs[int(attr.refoutidx)-1])
                    continue
                if attr.sref:
                    ref_alias = fh.search_origin_reference(attr.sref)
                    exprs.append(ref_alias)
                    continue
                # {Attr -> Parent Ds}
                accum_sort[attr] = get_parent_ds(ds_list, attr)
    # filter only Agg nodes (GROUP BY)
    if isinstance(source_obj, Agg):
        out_ds: Out = fh.find_ancestor_of_type(source_obj, (Ds,)).out
        exprs = source_obj.exprs
        for e in exprs:
            for attr in traverse_and_collect(collect_attr, e):
                if attr.refoutidx:  # e.g. GROUP BY 1: resolve the positional reference to the matching out attribute
                    exprs.append(out_ds.exprs[int(attr.refoutidx)-1])
                    continue
                if attr.sref:
                    exprs.append(fh.search_origin_reference(attr.sref))
                    continue
                # {Attr -> Parent Ds}
                accum_group[attr] = get_parent_ds(ds_list, attr)

    for child in source_obj.get_children():
        collect_group_and_sort_attributes(child)


def traverse_and_collect(func: callable, *args):
    """
    Utility function used to call an input function recursively (for all the nodes in the tree)
    :param func: the function to call recursively
    :param args: list of optional arguments to pass to the function
    :return: the list of nodes collected by func
    """
    accum = []
    func(*args, accum)
    return accum


def aggregate_sort(agg: dict):
    """
    Aggregate the collected ORDER BY info into a pandas dataframe
    :param agg: dictionary mapping each Attr to its parent Ds
    :return: a new Pandas Dataframe
    """
    rows = []
    for k, v in agg.items():
        d = {
            "column name": get_col_name(k),                                                         # the attribute identifier
            "table type": TABLE_TYPES.get(v.type_ or v.subType, 'Physical Table') if v else None,   # ds type (e.g. physical/cte/inline/...)
            "direction": k.direction or 'ASC',                                                      # ASC or DESC
            "table name": get_ds_name(v) if v else None                                             # the ds identifier (fully qualified name, plus alias if any)
        }
        if d not in rows:
            rows.append(d)
    df = pd.DataFrame(data=rows)
    return df


def aggregate_group(agg: dict):
    """
    Aggregate the collected GROUP BY info into a pandas dataframe
    :param agg: dictionary mapping each Attr to its parent Ds
    :return: a new Pandas Dataframe
    """
    rows = []
    for k, v in agg.items():
        d = {
            "column name": get_col_name(k),                                                         # the attribute identifier
            "table type": TABLE_TYPES.get(v.type_ or v.subType, 'Physical Table') if v else None,   # ds type (e.g. physical/cte/inline/...)
            "direction": k.direction or 'ASC',                                                      # ASC or DESC
            "table name": get_ds_name(v) if v else None                                             # the ds identifier (fully qualified name, plus alias if any)
        }
        if d not in rows:
            rows.append(d)
    df = pd.DataFrame(data=rows)
    return df


# The SQL query to be parsed
sql = """%sql%"""

# Initializing the SDK
fh = FlowHighSubmissionClass.from_sql(sql)

accum_sort = {}
accum_group = {}
# loop over the list of input Statements
for statement in fh.get_statements():
    ds_list = traverse_and_collect(collect_ds, statement)
    collect_group_and_sort_attributes(statement)
    print(aggregate_sort(accum_sort))
    print(aggregate_group(accum_group))
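
The dataframes above are printed per statement. If you want to keep the results instead, the same aggregations can be written out with pandas; a minimal sketch (the file names are illustrative):

# Optional: persist the aggregated results to CSV instead of printing them
aggregate_sort(accum_sort).to_csv("order_by_columns.csv", index=False)
aggregate_group(accum_group).to_csv("group_by_columns.csv", index=False)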