# Source code for clout._loaders.cli

import dataclasses
import functools
import os
import sys
import typing as t

import attr
import click
import desert
import glom
import inflection
import lark
import marshmallow
import typing_inspect

import clout.exceptions

from .. import _util
from . import parsing


# Default value of the ``default`` parameter on the ``CLI`` methods in this module.
NO_DEFAULT = "__NO_DEFAULT__"
# Name of the outermost group built by ``CLI.get_command``; ``CLI.prep`` also
# uses it as the first argument handed to the parser.
TOP_LEVEL_NAME = "top_level_name"
# The type of ``None``.
NoneType = type(None)
# Annotation-only alias for a dataclass type (used by ``command`` in this module).
Dataclass = t.NewType("Dataclass", type)


def is_python_syntax(s: str) -> bool:
    """Report whether ``black`` accepts *s* as Python source.

    ``black`` is imported lazily so that it is only needed when this check
    actually runs.

    Args:
        s: Candidate source text.

    Returns:
        ``False`` when formatting *s* raises ``black.InvalidInput``,
        ``True`` otherwise.
    """
    import black

    try:
        black.format_str(s, 88)
    except black.InvalidInput:
        accepted = False
    else:
        accepted = True
    return accepted


def pythonify(obj):
    """Return Python source text that can stand in for *obj*.

    The candidates are tried in order:

    1. ``repr(obj)``, when it is valid Python syntax.
    2. ``obj.__name__``, when *obj* is callable and has a name that is
       valid Python syntax.
    3. The placeholder ``"..."``.

    Args:
        obj: Any object.

    Returns:
        A string of valid Python syntax.
    """
    text = repr(obj)
    if is_python_syntax(text):
        return text

    # Not every callable carries ``__name__`` (``functools.partial`` objects
    # and instances that define ``__call__`` do not), so look it up
    # defensively instead of letting AttributeError escape.
    name = getattr(obj, "__name__", None) if callable(obj) else None
    if isinstance(name, str) and is_python_syntax(name):
        return name

    return "..."


class Debug:
    """Mixin whose ``__repr__`` looks like a constructor call.

    Every instance attribute whose name does not start with an underscore is
    rendered as ``name=value`` (values go through ``pythonify``), and the
    resulting call expression is formatted with ``black``.
    """

    def __repr__(self):
        import black

        rendered = [
            f"{key}={pythonify(value)}"
            for key, value in vars(self).items()
            if not key.startswith("_")
        ]
        joined = ", ".join(rendered)
        class_name = type(self).__name__
        return black.format_str(f"{class_name}({joined})", line_length=88)


class Option(Debug, click.Option):
    """A ``click.Option`` that uses the ``Debug`` repr."""


class Group(Debug, parsing.CountingGroup):
    """A ``parsing.CountingGroup`` with the ``Debug`` repr, always called non-standalone."""

    def __call__(self, *args, **kwargs):
        """Invoke the group with ``standalone_mode=False``.

        A ``SystemExit`` carrying exit code 0 is swallowed and ``None`` is
        returned; any other ``SystemExit`` propagates unchanged.
        """
        try:
            outcome = super().__call__(*args, **kwargs, standalone_mode=False)
        except SystemExit as exit_request:
            if exit_request.code == 0:
                return None
            raise
        return outcome


class DebuggableCommand(Debug, parsing.CountingCommand):
    """A ``parsing.CountingCommand`` with the ``Debug`` repr."""

    def __call__(self, *args, **kwargs):
        """Invoke the command, treating a clean exit as a normal return.

        A ``SystemExit`` carrying exit code 0 is swallowed and ``None`` is
        returned; any other ``SystemExit`` propagates unchanged.
        """
        try:
            outcome = super().__call__(*args, **kwargs)
        except SystemExit as exit_request:
            if exit_request.code == 0:
                return None
            raise
        return outcome


class MarshmallowFieldParam(click.ParamType):
    """click parameter type that converts values with a marshmallow field."""

    def __init__(self, field):
        # The marshmallow field whose ``deserialize`` performs the conversion.
        self.field = field

    @property
    def name(self):
        """The class name of the wrapped field (text after the last dot, if any)."""
        class_name = type(self.field).__name__
        return class_name.rsplit(".", 1)[-1]

    def convert(self, value, param, ctx):
        """Deserialize *value* with the wrapped field.

        Raises:
            click.BadParameter: when the field raises a marshmallow
                ``ValidationError``; the original error is chained.
        """
        try:
            converted = self.field.deserialize(value)
        except marshmallow.exceptions.ValidationError as error:
            raise click.BadParameter(
                message=str(error), ctx=ctx, param=param
            ) from error
        return converted


# Maps a marshmallow field class to the click parameter type used when an
# Option is built for a field of that class. Lookups in this module are by
# exact class (``type(field)``), so subclasses are not matched.
MM_TO_CLICK = {
    marshmallow.fields.String: click.STRING,
    marshmallow.fields.Int: click.INT,
    marshmallow.fields.Float: click.FLOAT,
    marshmallow.fields.Raw: click.STRING,
    marshmallow.fields.Bool: click.BOOL,
}


def make_help_command():
    """Build a hidden ``DebuggableCommand`` whose name is ``--help``."""
    help_command = DebuggableCommand(name="--help", hidden=True)
    return help_command


@functools.singledispatch
def make_param_from_field(
    field: marshmallow.fields.Field, data, default
) -> click.Parameter:
    """Build a click option for a marshmallow field (generic fallback).

    This is the ``singledispatch`` base implementation; it handles every field
    class that has no more specific handler registered.

    Args:
        field: The marshmallow field to expose on the command line.
        data: Optional mapping of keyword arguments for ``Option`` taken from
            the field's metadata. When truthy it is used as the option
            definition, with ``type`` filled in if absent.
        default: Default value for the option. Only applied when *data* is
            falsy. NOTE(review): the registered handlers below pass
            ``default`` in both branches - confirm whether ignoring it here
            is intended.

    Returns:
        The ``Option`` for *field*.
    """
    if data:
        option_kwargs = data.copy()
        if "type" not in option_kwargs:
            # Field classes listed in MM_TO_CLICK have their own registered
            # handlers, so a field reaching this fallback is normally absent
            # from the table. Convert through the field itself, exactly as
            # the no-metadata branch does, rather than raising KeyError.
            click_type = MM_TO_CLICK.get(type(field))
            if click_type is None:
                click_type = MarshmallowFieldParam(field)
            option_kwargs["type"] = click_type

        return Option(show_default=True, **option_kwargs)

    return Option(
        ["--" + field.name],
        type=MarshmallowFieldParam(field),
        required=False,
        default=default,
        show_default=True,
    )


@make_param_from_field.register(marshmallow.fields.Boolean)
def _(field: marshmallow.fields.Boolean, data, default) -> Option:
    """Build the option for a boolean field.

    With metadata in *data*, the option is built from it plus ``default``.
    Otherwise an on/off flag pair ``--name/--no-name`` is created from the
    dasherized field name. The flag is required only when ``field.missing``,
    ``field.default`` and *default* all compare equal to
    ``marshmallow.missing``; the default is shown only when not required.
    """
    if not data:
        flag = _util.dasherize(field.name)
        # Chained comparison: true only if every adjacent pair is equal.
        nothing_supplied = (
            field.missing == marshmallow.missing == field.default == default
        )
        return Option(
            [f"--{flag}/--no-{flag}"],
            default=default,
            required=nothing_supplied,
            is_flag=True,
            show_default=not nothing_supplied,
        )

    return Option(**data, default=default)


def _click_type_for(field):
    """Return the click parameter type to use for *field*.

    The field's classes are searched in method-resolution order so that a
    subclass of a mapped field class uses its parent's entry. Fields with no
    mapped class at all fall back to ``click.STRING``.
    """
    for klass in type(field).__mro__:
        if klass in MM_TO_CLICK:
            return MM_TO_CLICK[klass]
    # NOTE(review): assumes the raw text is parsed later when the schema is
    # loaded, so passing it through as a string is the safe choice.
    return click.STRING


@make_param_from_field.register(marshmallow.fields.String)
@make_param_from_field.register(marshmallow.fields.Int)
@make_param_from_field.register(marshmallow.fields.Float)
@make_param_from_field.register(marshmallow.fields.Date)
@make_param_from_field.register(marshmallow.fields.DateTime)
@make_param_from_field.register(marshmallow.fields.Raw)
def _(field, data, default) -> Option:
    """Build the option for string, numeric, date/time and raw fields.

    Args:
        field: The marshmallow field to expose on the command line.
        data: Optional mapping of keyword arguments for ``Option`` taken from
            the field's metadata; ``type`` is filled in if absent.
        default: Default value for the option (applied in both branches).

    Returns:
        The ``Option`` for *field*.

    The click type comes from ``_click_type_for`` rather than an exact-class
    lookup, because this handler is registered for classes that are not in
    ``MM_TO_CLICK`` (``Date``, ``DateTime``) and also receives subclasses of
    the registered classes; an exact lookup raised ``KeyError`` for those.
    """
    if data:
        option_kwargs = data.copy()
        if "type" not in option_kwargs:
            option_kwargs["type"] = _click_type_for(field)
        return Option(show_default=True, **option_kwargs, default=default)

    return Option(
        ["--" + field.name],
        type=_click_type_for(field),
        show_default=True,
        default=default,
    )


def extract(mapping, path):
    """Follow *path* through nested containers and return what it reaches.

    Each entry of *path* is used, in order, as a subscript on the result of
    the previous step, starting from *mapping*. An empty *path* returns
    *mapping* itself. Any lookup error from a subscript propagates.
    """
    return functools.reduce(lambda node, key: node[key], path, mapping)


def get_default(field, path, default_map):
    """Work out the default value for *field*.

    Sources, in order of precedence:

    1. The environment variable named by
       ``field.metadata["clout"]["cli"]["envvar"]``, when that key exists and
       the variable is set.
    2. The entry of *default_map* reached by following ``path[1:]`` (the
       first path element is skipped).
    3. ``field.default``, when that lookup raises ``KeyError``.
    """
    # XXX The envvar logic should be somewhere else.
    cli_settings = field.metadata.get("clout", {}).get("cli", {})
    env_name = cli_settings.get("envvar")
    if env_name is not None:
        env_value = os.environ.get(env_name)
        if env_value is not None:
            return env_value

    try:
        return extract(default_map, path[1:])
    except KeyError:
        return field.default


@attr.dataclass(frozen=True)
class CLI:
    """Builds click commands from a dataclass and parses arguments into it.

    The instance is frozen; use :meth:`set` to obtain a modified copy.
    """

    # Passed as ``context_settings`` to every command built from a schema.
    # Its "default_map" entry, when present, supplies option defaults.
    context_settings: t.Dict[str, t.Any] = attr.ib(factory=dict)
    # Not read anywhere in this class.
    # NOTE(review): presumably consumed by code outside this module - confirm.
    inherits: t.FrozenSet[str] = frozenset({"app_name"})
    # Key looked up in the ``metadata`` mapping given to ``get_command``.
    metadata_key: str = "cli"
    # Fallback argument list used by ``prep`` when none is passed to it.
    args: t.List[str] = attr.ib(factory=list)
    # Dasherized to name the command when ``metadata`` carries no "name".
    app_name: t.Optional[str] = None

    def make_command_from_schema(
        self, schema: marshmallow.Schema, path: t.Sequence[str]
    ) -> click.BaseCommand:
        """Recursively turn a marshmallow schema into a click command.

        Nested fields become sub-commands (built by recursing with the field
        name appended to *path*); every other field becomes a parameter.

        Args:
            schema: Schema instance whose fields are walked.
            path: Names from the root command down to this one. It is
                extended with ``path + (field.name,)``, so callers pass a
                tuple. The last element names the returned command.

        Returns:
            A chained ``Group`` named ``path[-1]``.

        Raises:
            TypeError: for a schema entry that is not a marshmallow field.
        """
        params = []
        commands = []

        for field in schema.fields.values():

            if isinstance(field, marshmallow.fields.Nested):
                commands.append(
                    self.make_command_from_schema(
                        field.schema, path=path + (field.name,)
                    )
                )
            elif isinstance(field, marshmallow.fields.Field):
                user_specified = field.metadata.get("cli")
                default = get_default(
                    field,
                    path=path + (field.name,),
                    default_map=self.context_settings.get("default_map", {}),
                )

                # A ready-made click parameter in the metadata is used as is;
                # anything else is handed to the field-type dispatcher.
                if isinstance(field.metadata.get("cli"), click.Parameter):
                    param = field.metadata["cli"]

                else:

                    param = make_param_from_field(
                        field, user_specified, default=default
                    )

                params.append(param)
            else:
                raise TypeError(field)

        # The help command is appended unconditionally, so ``commands`` is
        # never empty below.
        commands.append(make_help_command())

        help = getattr(schema, "help", None)
        if commands:
            return Group(
                name=path[-1],
                commands={c.name: c for c in commands},
                params=params,
                chain=True,
                result_callback=lambda *a, **kw: (a, kw),
                help=help,
                short_help=help,
                context_settings=self.context_settings,
            )
        # NOTE(review): unreachable because ``commands`` always holds the help
        # command. ``identity`` is not defined in the visible part of this
        # module, so this would fail if it were ever reached.
        return DebuggableCommand(
            name=path[-1],
            params=params,
            callback=identity,
            help=help,
            short_help=help,
            context_settings=self.context_settings,
        )

    def get_command(
        self,
        typ: t.Type,
        default=NO_DEFAULT,
        metadata: t.Optional[t.Mapping[str, t.Any]] = None,
        args=(),
    ):
        """Build the top-level command group for *typ*.

        Args:
            typ: The dataclass type to build a command for.
            default: Accepted but not used in this method.
            metadata: Optional mapping. If its ``self.metadata_key`` entry is
                already a click command or parameter, that object is used
                directly; otherwise a command is generated from the schema
                of *typ*, named by the "name" entry or the dasherized
                ``app_name``.
            args: Accepted but not used in this method.

        Returns:
            A ``Group`` named ``TOP_LEVEL_NAME`` containing the command and a
            help command, sharing the inner command's callback.
        """
        metadata = metadata or {}

        cli_metadata: t.Union[
            t.Dict[str, t.Any], click.BaseCommand, click.Parameter
        ] = metadata.get(self.metadata_key, None)
        if isinstance(cli_metadata, (click.BaseCommand, click.Parameter)):
            command = cli_metadata
        else:

            # The fallback name is computed even when "name" is present.
            name = metadata.get("name", _util.dasherize(self.app_name))

            schema = desert.schema_class(typ)()
            command = self.make_command_from_schema(schema, path=(name,))

            def schema_load(*a, **kw):
                """Load via the schema, re-raising validation errors as clout's own."""
                try:
                    return schema.load(*a, **kw)
                except marshmallow.exceptions.ValidationError as e:
                    raise clout.exceptions.ValidationError(*e.args) from e

            command.callback = schema_load

        command = Group(
            name=TOP_LEVEL_NAME,
            commands={c.name: c for c in [command, make_help_command()]},
            callback=command.callback,
        )
        return command

    def prep(
        self,
        typ: t.Type,
        default=NO_DEFAULT,
        metadata: t.Optional[t.Mapping[str, t.Any]] = None,
        args=(),
    ):
        """Parse the command-line arguments for *typ*.

        The arguments come from *args*, else ``self.args``, else
        ``sys.argv[2:]``, prefixed with ``TOP_LEVEL_NAME`` and the dasherized
        name of *typ*.

        Returns:
            The single value held by the parser's result mapping.

        On a click ``BadParameter`` the error is shown and the process exits
        with status 1. On the listed lark errors the help text plus ``EPILOG``
        is printed, then the error is re-raised if the ``CLI_SHOW_TRACEBACK``
        environment variable is a non-zero integer, otherwise the process
        exits with status 1.
        """

        command = self.get_command(typ, default, metadata, args)

        parser = parsing.Parser(command, callback=command.callback, use_defaults=True)
        # NOTE(review): the last fallback skips the first two argv entries -
        # confirm that this is intended.
        cli_args = (TOP_LEVEL_NAME, _util.dasherize(typ.__name__)) + tuple(
            args or self.args or sys.argv[2:]
        )

        try:
            result = parser.parse_args(cli_args)
        except click.exceptions.BadParameter as e:
            e.show()
            sys.exit(1)
        except (
            lark.exceptions.ParseError,
            lark.exceptions.UnexpectedCharacters,
            lark.exceptions.VisitError,
        ):

            print(command.get_help(click.Context(command)) + EPILOG)
            if int(os.environ.get("CLI_SHOW_TRACEBACK", 0)):
                raise
            else:
                sys.exit(1)
        except clout.exceptions.MissingInput as e:
            # Parse again with what was found so far plus "--help".
            result = parser.parse_args(e.found + ["--help"])

        # Fails unless the result holds exactly one value.
        [value] = result.values()

        return value

    def build(
        self,
        typ: t.Type,
        default=NO_DEFAULT,
        metadata: t.Optional[t.Mapping[str, t.Any]] = None,
        args=(),
    ):
        """Parse the arguments and pass them to the command's callback.

        Returns:
            Whatever the callback of the command from ``get_command`` returns
            when called with the value produced by ``prep``.
        """

        command = self.get_command(typ, default, metadata, args)

        result = command.callback(self.prep(typ, default, metadata, args))

        return result

    def set(self, **kw):
        """Return a copy of this instance with the given attributes replaced."""
        return attr.evolve(self, **kw)


class NonStandaloneCommand(click.Command):
    """A ``click.Command`` whose ``main`` defaults to ``standalone_mode=False``."""

    def main(self, *args, standalone_mode=False, **kwargs):
        """Forward to ``click.Command.main`` with the chosen ``standalone_mode``."""
        kwargs["standalone_mode"] = standalone_mode
        return super().main(*args, **kwargs)


# Note appended to command epilogs and to the help text printed after a parse
# error; it names the environment variable that enables tracebacks.
EPILOG = "\n\nNote:\n  export CLI_SHOW_TRACEBACK=1 to show traceback on error.\n"


class Command(click.Command):
    """A :class:`click.Command` built from an :func:`attr.dataclass` or :func:`dataclasses.dataclass`.

    All command-line arguments are collected unprocessed into a single
    ``args`` argument, parsed by :class:`CLI` into an instance of the
    dataclass, and that instance is passed to ``callback``.
    """

    def __init__(
        self,
        type,
        *args,
        name=None,
        app_name=None,
        callback=lambda x: x,
        params=None,
        context_settings=None,
        epilog=None,
        **kwargs,
    ):
        """Create the command.

        Args:
            type: The attrs class or dataclass to build from the arguments.
            *args: Forwarded positionally to ``click.Command``.
                NOTE(review): ``name`` is also passed by keyword, so a
                positional value here would likely collide with it - confirm.
            name: The command name.
            app_name: Application name given to :class:`CLI`; defaults to the
                dasherized name of *type*.
            callback: Called with the built instance; its return value is the
                command's result. Defaults to returning the instance.
            params: Extra click parameters placed before the catch-all
                ``args`` argument.
            context_settings: click context settings. The mapping is copied,
                so the caller's object is never modified.
            epilog: Text shown after the help; ``EPILOG`` is appended.
            **kwargs: Forwarded to ``click.Command``.

        Raises:
            TypeError: if *type* is neither an attrs class nor a dataclass.
        """
        if not (attr.has(type) or dataclasses.is_dataclass(type)):
            raise TypeError(f"Need a dataclass, got {type} of type {type.__class__}")
        self.app_name = app_name
        epilog = (epilog or "") + EPILOG

        # Work on a copy: setting the key on the caller's dict would leak
        # ``ignore_unknown_options`` into anything else sharing that dict.
        context_settings = dict(context_settings or {})
        context_settings["ignore_unknown_options"] = True
        super().__init__(
            *args,
            name=name,
            **kwargs,
            add_help_option=False,
            epilog=epilog,
            context_settings=context_settings,
        )
        # The trailing catch-all argument hands every remaining token to CLI.
        self.params = (params or []) + [
            click.Argument(["args"], type=click.UNPROCESSED, nargs=-1)
        ]
        self.callback = lambda args: callback(
            CLI(
                app_name=app_name or _util.dasherize(type.__name__),
                context_settings=context_settings,
            ).build(type, args=args)
        )

    def build(self):
        """Return the command's result for the current command line arguments.

        Runs :meth:`main` with ``standalone_mode=False``. With the default
        ``callback`` the result is the instance built by :class:`CLI` for the
        type given to the constructor.
        """
        return self.main(standalone_mode=False)

    def main(self, *args, **kwargs):
        """Run the command.

        Upcalls directly to the ``main`` method inherited from
        :class:`click.Command`. Positional arguments are accepted and
        forwarded so that calling the command object with positional
        arguments also works.
        """
        return super().main(*args, **kwargs)
def command(type: Dataclass, **kwargs):
    """Return a decorator that replaces a function with a :class:`clout.Command`.

    The resulting command is built for *type*, with the decorated function as
    its ``callback`` and *kwargs* forwarded to the :class:`Command`
    constructor. Compare to :func:`click.command()`.
    """

    def build_command(callback):
        return Command(type, callback=callback, **kwargs)

    return build_command