# type: ignore
import collections
import sys
import typing

try:
    from typing_extensions import _AnnotatedAlias
except Exception:
    try:
        from typing import _AnnotatedAlias
    except Exception:
        _AnnotatedAlias = None

try:
    from typing_extensions import get_type_hints as _get_type_hints
except Exception:
    from typing import get_type_hints as _get_type_hints

try:
    from typing_extensions import NotRequired, Required
except Exception:
    try:
        from typing import NotRequired, Required
    except Exception:
        Required = NotRequired = None


if Required is None and _AnnotatedAlias is None:
    # No extras available, so no `include_extras`
    get_type_hints = _get_type_hints
else:

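    # Extras are potentially in use, so ask `get_type_hints` to keep them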
    def get_type_hints(obj):
        return _get_type_hints(obj, include_extras=True)


# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
# check is to try it and see. This check can be removed when we drop support
# for Python 3.10.
try:
    typing.ForwardRef("Foo", is_class=True)
except TypeError:

    def _forward_ref(value):
        return typing.ForwardRef(value, is_argument=False)

else:

    def _forward_ref(value):
        return typing.ForwardRef(value, is_argument=False, is_class=True)


def _apply_params(obj, mapping):
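    """Substitute typevars in ``obj`` using ``mapping``.

    Handles both parametrized generics (via ``__parameters__``) and bare
    ``TypeVar`` instances; anything else is returned unchanged.
    """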
    if params := getattr(obj, "__parameters__", None):
        args = tuple(mapping.get(p, p) for p in params)
        return obj[args]
    elif isinstance(obj, typing.TypeVar):
        return mapping.get(obj, obj)
    return obj


def _get_class_mro_and_typevar_mappings(obj):
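    """Compute the MRO for a class (or parametrized generic), along with a
    mapping from each generic base class to its typevar substitutions,
    gathered by walking ``__orig_bases__``.
    """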
    mapping = {}

    if isinstance(obj, type):
        cls = obj
    else:
        cls = obj.__origin__

    def inner(c, scope):
        if isinstance(c, type):
            cls = c
            new_scope = {}
        else:
            cls = getattr(c, "__origin__", None)
            if cls in (None, object, typing.Generic) or cls in mapping:
                return
            params = cls.__parameters__
            args = tuple(_apply_params(a, scope) for a in c.__args__)
            assert len(params) == len(args)
            mapping[cls] = new_scope = dict(zip(params, args))

        if issubclass(cls, typing.Generic):
            bases = getattr(cls, "__orig_bases__", cls.__bases__)
            for b in bases:
                inner(b, new_scope)

    inner(obj, {})
    return cls.__mro__, mapping


def get_class_annotations(obj):
    """Get the annotations for a class.

    This is similar to ``typing.get_type_hints``, except:

    - We maintain it
    - It leaves extras like ``Annotated``/``ClassVar`` alone
    - It resolves any parametrized generics in the class mro. The returned
      mapping may still include ``TypeVar`` values, but those should be
      treated as their unparametrized variants (i.e. equivalent to ``Any``
      in the common case).

    Note that this function doesn't check that ``Generic`` types are being
    used properly - invalid uses of ``Generic`` may slip through without
    complaint.

    The assumption here is that the user is making use of a static analysis
    tool like ``mypy``/``pyright`` already, which would catch misuse of these
    APIs.
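
    For example, a rough sketch of the expected behavior (``Base``/``Sub``
    below are purely illustrative):

    >>> import typing
    >>> T = typing.TypeVar("T")
    >>> class Base(typing.Generic[T]):
    ...     x: T
    >>> class Sub(Base[int]):
    ...     y: str
    >>> get_class_annotations(Sub)
    {'y': <class 'str'>, 'x': <class 'int'>}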
    """
    hints = {}
    mro, typevar_mappings = _get_class_mro_and_typevar_mappings(obj)

    for cls in mro:
        if cls in (typing.Generic, object):
            continue

        mapping = typevar_mappings.get(cls)
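        # Namespaces used to evaluate any string/forward-reference annotations
        # defined on this class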
        cls_locals = dict(vars(cls))
        cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})

        ann = cls.__dict__.get("__annotations__", {})
        for name, value in ann.items():
            if name in hints:
                continue
            if value is None:
                value = type(None)
            elif isinstance(value, str):
                value = _forward_ref(value)
            value = typing._eval_type(value, cls_locals, cls_globals)
            if mapping is not None:
                value = _apply_params(value, mapping)
            hints[name] = value
    return hints


# A mapping from a type annotation (or annotation __origin__) to the concrete
# python type that msgspec will use when decoding. THIS IS PRIVATE FOR A
# REASON. DON'T MUCK WITH THIS.
_CONCRETE_TYPES = {
    list: list,
    tuple: tuple,
    set: set,
    frozenset: frozenset,
    dict: dict,
    typing.List: list,
    typing.Tuple: tuple,
    typing.Set: set,
    typing.FrozenSet: frozenset,
    typing.Dict: dict,
    typing.Collection: list,
    typing.MutableSequence: list,
    typing.Sequence: list,
    typing.MutableMapping: dict,
    typing.Mapping: dict,
    typing.MutableSet: set,
    typing.AbstractSet: set,
    collections.abc.Collection: list,
    collections.abc.MutableSequence: list,
    collections.abc.Sequence: list,
    collections.abc.MutableSet: set,
    collections.abc.Set: set,
    collections.abc.MutableMapping: dict,
    collections.abc.Mapping: dict,
}


def get_typeddict_info(obj):
    if isinstance(obj, type):
        cls = obj
    else:
        cls = obj.__origin__

    raw_hints = get_class_annotations(obj)

    if hasattr(cls, "__required_keys__"):
        required = set(cls.__required_keys__)
    elif cls.__total__:
        required = set(raw_hints)
    else:
        required = set()

    # Both `typing.TypedDict` and `typing_extensions.TypedDict` have a bug
    # where `Required`/`NotRequired` aren't properly detected at runtime when
    # `from __future__ import annotations` is enabled, meaning
    # `__required_keys__` may be incorrect. This block works around the issue
    # by amending the set of required keys as needed, while also stripping off
    # any `Required`/`NotRequired` wrappers.
    hints = {}
    for k, v in raw_hints.items():
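        # Use `False` (not `None`) as the default so that a hint without an
        # `__origin__` can never match `Required`/`NotRequired`, even when
        # those imports were unavailable above and are `None`.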
        origin = getattr(v, "__origin__", False)
        if origin is Required:
            required.add(k)
            hints[k] = v.__args__[0]
        elif origin is NotRequired:
            required.discard(k)
            hints[k] = v.__args__[0]
        else:
            hints[k] = v
    return hints, required


def get_dataclass_info(obj):
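    """Extract field information from a dataclass or attrs class.

    Returns a ``(cls, fields, defaults, pre_init, post_init)`` tuple, where
    ``fields`` is a tuple of ``(name, type, is_factory)`` triples (required
    fields first, then optional ones), and ``defaults`` holds the default
    values or factories for the trailing optional fields.
    """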
    if isinstance(obj, type):
        cls = obj
    else:
        cls = obj.__origin__
    hints = get_class_annotations(obj)
    required = []
    optional = []
    defaults = []

    if hasattr(cls, "__dataclass_fields__"):
        from dataclasses import _FIELD, _FIELD_INITVAR, MISSING

        for field in cls.__dataclass_fields__.values():
            if field._field_type is not _FIELD:
                if field._field_type is _FIELD_INITVAR:
                    raise TypeError(
                        "dataclasses with `InitVar` fields are not supported"
                    )
                continue
            name = field.name
            typ = hints[name]
            if field.default is not MISSING:
                defaults.append(field.default)
                optional.append((name, typ, False))
            elif field.default_factory is not MISSING:
                defaults.append(field.default_factory)
                optional.append((name, typ, True))
            else:
                required.append((name, typ, False))

        required.extend(optional)

        pre_init = None
        post_init = getattr(cls, "__post_init__", None)
    else:
        from attrs import NOTHING, Factory

        fields_with_validators = []

        for field in cls.__attrs_attrs__:
            name = field.name
            typ = hints[name]
            default = field.default
            if default is not NOTHING:
                if isinstance(default, Factory):
                    if default.takes_self:
                        raise NotImplementedError(
                            "Support for default factories with `takes_self=True` "
                            "is not implemented. File a GitHub issue if you need "
                            "this feature!"
                        )
                    defaults.append(default.factory)
                    optional.append((name, typ, True))
                else:
                    defaults.append(default)
                    optional.append((name, typ, False))
            else:
                required.append((name, typ, False))

            if field.validator is not None:
                fields_with_validators.append(field)

        required.extend(optional)

        pre_init = getattr(cls, "__attrs_pre_init__", None)
        post_init = getattr(cls, "__attrs_post_init__", None)

        if fields_with_validators:
            post_init = _wrap_attrs_validators(fields_with_validators, post_init)

    return cls, tuple(required), tuple(defaults), pre_init, post_init


def _wrap_attrs_validators(fields, post_init):
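    """Wrap an (optional) ``__attrs_post_init__`` hook so that attrs field
    validators run first, followed by the original hook (if any)."""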
    def inner(obj):
        for field in fields:
            field.validator(obj, field, getattr(obj, field.name))
        if post_init is not None:
            post_init(obj)

    return inner


def rebuild(cls, kwargs):
    """Used to unpickle Structs with keyword-only fields"""
    return cls(**kwargs)
