Skip to content

Formatter

Formats a DAX expression into a standardized format.

Source code in pbi_parsers/dax/formatter.py
class Formatter:
    """Formats a DAX expression into a standardized format.

    Each supported expression node type is rendered by a dedicated
    ``_format_<kind>`` classmethod. ``_format_helper`` dispatches on the
    node's exact type and attaches any leading/trailing comments.
    Formatted fragments never end in a newline, so they can be embedded
    inline inside larger expressions.
    """

    # One level of indentation in formatted output.
    _INDENT = " " * 4
    # Characters allowed in a table name that is emitted without single quotes.
    _BARE_TABLE_CHARS = frozenset(string.ascii_letters + string.digits + "_")

    def __init__(self, expression: "Expression") -> None:
        """Store the root expression to be formatted.

        Args:
            expression: Root node of a parsed DAX expression tree.
        """
        self.expression = expression

    def format(self) -> str:
        """Return the formatted DAX text for the stored expression.

        Raises:
            TypeError: If the tree contains an unsupported expression type.
        """
        return self._format_helper(self.expression)

    @classmethod
    def _format_helper(cls, expr: Expression) -> str:
        """Format ``expr`` by dispatching on its exact type.

        Pre-comments are placed on the line(s) above the formatted node;
        post-comments follow it on the same line, separated by two spaces.

        Raises:
            TypeError: If ``type(expr)`` has no registered formatter. Lookup
                uses the exact type, so subclasses are not matched.
        """
        mapper: Any = {
            AddSubExpression: cls._format_add_sub,
            AddSubUnaryExpression: cls._format_add_sub_unary,
            ArrayExpression: cls._format_array,
            ComparisonExpression: cls._format_comparison,
            ColumnExpression: cls._format_column,
            ConcatenationExpression: cls._format_concatenation,
            DivMulExpression: cls._format_div_mul,
            ExponentExpression: cls._format_exponent,
            FunctionExpression: cls._format_function,
            HierarchyExpression: cls._format_hierarchy,
            IdentifierExpression: cls._format_identifier,
            InExpression: cls._format_in,
            KeywordExpression: cls._format_keyword,
            LiteralNumberExpression: cls._format_literal_number,
            LiteralStringExpression: cls._format_literal_string,
            LogicalExpression: cls._format_logical,
            MeasureExpression: cls._format_measure,
            NoneExpression: lambda _: "",
            ParenthesesExpression: cls._format_parens,
            ReturnExpression: cls._format_return,
            TableExpression: cls._format_table,
            VariableExpression: cls._format_variable,
        }
        formatter = mapper.get(type(expr))
        if formatter is None:
            msg = f"Unsupported expression type: {type(expr).__name__}"
            raise TypeError(msg)

        base_format = formatter(expr)
        if expr.pre_comments:
            base_format = f"{format_comments(expr.pre_comments, 0)}\n{base_format}"
        if expr.post_comments:
            base_format = f"{base_format}  {format_comments(expr.post_comments, 0)}"
        return base_format

    @classmethod
    def _indent(cls, text: str) -> str:
        """Prefix every non-blank line of ``text`` with one indentation level."""
        return textwrap.indent(text, cls._INDENT)

    @classmethod
    def _unquote_table_name(cls, name: str) -> str:
        """Drop the single quotes around a table name when they are unnecessary.

        ``'Sales'`` becomes ``Sales``. Names containing any character other
        than ASCII letters, digits or ``_`` (e.g. ``'Sales Data'``) keep their
        quotes, as do degenerate inputs such as ``''`` or a lone ``'``.
        """
        is_quoted = len(name) > 2 and name.startswith("'") and name.endswith("'")
        if is_quoted and all(c in cls._BARE_TABLE_CHARS for c in name[1:-1]):
            return name[1:-1]
        return name

    @classmethod
    def _format_binary(cls, expr: Any) -> str:
        """Format ``left <operator> right`` with single spaces around the operator."""
        left = cls._format_helper(expr.left)
        right = cls._format_helper(expr.right)
        return f"{left} {expr.operator.text} {right}"

    @classmethod
    def _format_add_sub(cls, expr: AddSubExpression) -> str:
        """Format an addition/subtraction as ``left op right``."""
        return cls._format_binary(expr)

    @classmethod
    def _format_add_sub_unary(cls, expr: AddSubUnaryExpression) -> str:
        """Format a unary +/- with the operator attached directly to its operand."""
        return f"{expr.operator.text}{cls._format_helper(expr.number)}"

    @classmethod
    def _format_array(cls, expr: ArrayExpression) -> str:
        """Format a table constructor with one indented element per line.

        The result ends at the closing brace, with no trailing newline, so it
        can be embedded inline (e.g. as a function argument or ``IN`` operand).
        An empty constructor is rendered as ``{}``.
        """
        if not expr.elements:
            return "{}"
        elements = ",\n".join(cls._format_helper(el) for el in expr.elements)
        body = cls._indent(elements)
        return f"{{\n{body}\n}}"

    @classmethod
    def _format_column(cls, expr: ColumnExpression) -> str:
        """Format a column reference as table name followed by the column token.

        The column token text is appended verbatim (presumably including its
        ``[...]`` brackets; confirm against the tokenizer).
        """
        table = cls._unquote_table_name(expr.table.text)
        return f"{table}{expr.column.text}"

    @classmethod
    def _format_comparison(cls, expr: ComparisonExpression) -> str:
        """Format a comparison as ``left op right``."""
        return cls._format_binary(expr)

    @classmethod
    def _format_concatenation(cls, expr: ConcatenationExpression) -> str:
        """Format a string concatenation as ``left op right``."""
        return cls._format_binary(expr)

    @classmethod
    def _format_div_mul(cls, expr: DivMulExpression) -> str:
        """Format a multiplication/division as ``left op right``."""
        return cls._format_binary(expr)

    @classmethod
    def _format_exponent(cls, expr: ExponentExpression) -> str:
        """Format exponentiation as ``base^power`` with no surrounding spaces."""
        base = cls._format_helper(expr.base)
        power = cls._format_helper(expr.power)
        return f"{base}^{power}"

    @classmethod
    def _format_function(cls, expr: FunctionExpression) -> str:
        """Format a function call, inline when short, otherwise one argument per line.

        The inline form is used when the combined length of the formatted
        arguments is below ``MAX_ARGUMENT_LENGTH``. Separators are not counted.
        """
        name = "".join(token.text for token in expr.name_parts)
        args = [cls._format_helper(arg) for arg in expr.args]
        if sum(len(x) for x in args) < MAX_ARGUMENT_LENGTH:
            arg_str = ", ".join(args)
            return f"{name}({arg_str})"
        # Built outside the f-string: backslashes are not allowed inside
        # f-string expressions before Python 3.12.
        arg_block = cls._indent(",\n".join(args))
        return f"{name}(\n{arg_block}\n)"

    @classmethod
    def _format_hierarchy(cls, expr: HierarchyExpression) -> str:
        """Format a hierarchy reference as ``<table><column>.<level>``."""
        table = cls._unquote_table_name(expr.table.text)
        return f"{table}{expr.column.text}.{expr.level.text}"

    @classmethod
    def _format_identifier(cls, expr: IdentifierExpression) -> str:
        """Return the identifier token text unchanged."""
        return expr.name.text

    @classmethod
    def _format_in(cls, expr: InExpression) -> str:
        """Format a membership test as ``value IN array``."""
        value = cls._format_helper(expr.value)
        array = cls._format_helper(expr.array)
        return f"{value} IN {array}"

    @classmethod
    def _format_keyword(cls, expr: KeywordExpression) -> str:
        """Return the keyword token text unchanged."""
        return expr.name.text

    @classmethod
    def _format_literal_number(cls, expr: LiteralNumberExpression) -> str:
        """Return the numeric literal token text unchanged."""
        return expr.value.text

    @classmethod
    def _format_literal_string(cls, expr: LiteralStringExpression) -> str:
        """Return the string literal token text unchanged (quotes included as tokenized)."""
        return expr.value.text

    @classmethod
    def _format_logical(cls, expr: LogicalExpression) -> str:
        """Format a logical operation as ``left op right``."""
        return cls._format_binary(expr)

    @classmethod
    def _format_measure(cls, expr: MeasureExpression) -> str:
        """Return the measure token text unchanged."""
        return expr.name.text

    @classmethod
    def _format_parens(cls, expr: ParenthesesExpression) -> str:
        """Format a parenthesized expression as ``(inner)``."""
        inner = cls._format_helper(expr.inner_statement)
        return f"({inner})"

    @classmethod
    def _format_return(cls, expr: ReturnExpression) -> str:
        """Format variable statements one per line, followed by ``RETURN <expr>``."""
        variable_strs = "\n".join(cls._format_helper(var) for var in expr.variable_statements)
        return_statement: str = cls._format_helper(expr.ret)
        # strip() drops the leading newline when there are no variable statements.
        return f"{variable_strs}\nRETURN {return_statement}".strip()

    @classmethod
    def _format_table(cls, expr: TableExpression) -> str:
        """Format a bare table reference, unquoting the name when safe."""
        return cls._unquote_table_name(expr.name.text)

    @classmethod
    def _format_variable(cls, expr: VariableExpression) -> str:
        """Format a variable assignment as ``<name> = <statement>``."""
        # NOTE(review): no ``VAR`` keyword is emitted here. Confirm whether
        # ``var_name`` already carries it; otherwise _format_return output would
        # not be valid DAX.
        return f"{expr.var_name.text} = {cls._format_helper(expr.statement)}"