# Source code for cosense3d.modules.plugin

# Copyright (c) OpenMMLab. All rights reserved. Modified by Yunshuang Yuan.
import inspect
from typing import Dict, Tuple, Union
from importlib import import_module

import torch.nn as nn
import re  # type: ignore


def infer_abbr(class_type: type) -> str:
    """Derive the abbreviation used to name a layer of the given class.

    Two rules are applied, in order:

    1. If the class exposes an ``_abbr_`` attribute, that attribute is
       returned unchanged.
    2. Otherwise the class name is converted to snake case, e.g.
       ``FancyBlock`` becomes ``fancy_block``.

    The snake-case conversion is modified from the `inflection lib
    <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

    :param class_type: The class to abbreviate.
    :return: The inferred abbreviation.
    :raises TypeError: If ``class_type`` is not a class.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')

    # Rule 1: an explicit abbreviation on the class wins.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_  # type: ignore

    # Rule 2: snake-case the class name.
    # First split an upper-case run from a following capitalised word
    # ("HTTPServer" -> "HTTP_Server"), then split lower-case/digit -> upper
    # transitions ("FancyBlock" -> "Fancy_Block").
    snake = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', class_type.__name__)
    snake = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', snake)
    return snake.replace('-', '_').lower()
def build_plugin_layer(cfg: Dict,
                       postfix: Union[int, str] = '',
                       **kwargs) -> Tuple[str, nn.Module]:
    """Build plugin layer.

    :param cfg: cfg should contain:

        - type (str): dotted path ``<importable.module>.<ClassName>``
          identifying the plugin layer class.
        - layer args: args needed to instantiate a plugin layer.
    :param postfix: appended into the abbreviation to create the layer
        name. Default: ''.
    :param kwargs: extra keyword arguments passed to the layer constructor
        together with the remaining entries of ``cfg``.
    :return: The first one is the concatenation of abbreviation and postfix.
        The second is the created plugin layer.
    :raises TypeError: if ``cfg`` is not a dict.
    :raises KeyError: if ``cfg`` has no ``type`` key, or the type cannot be
        resolved to an attribute of an importable module.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a copy so the caller's config dict is left untouched.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')

    try:
        pkg, cls = layer_type.rsplit('.', 1)
        # Modules expose their members as attributes; they have no ``.get``
        # method, so the class must be looked up with ``getattr``.
        plugin_layer = getattr(import_module(pkg), cls)
    except (AttributeError, ImportError, TypeError, ValueError) as err:
        # AttributeError: non-string type or class missing from the module.
        # ImportError: module not importable.
        # TypeError / ValueError: no dot in the path, empty or relative
        # module name.
        raise KeyError(f'Unrecognized plugin type {layer_type}') from err

    abbr = infer_abbr(plugin_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer
def build_plugin_module(cfg: Dict):
    """Instantiate a class that lives in a submodule of this package.

    :param cfg: config dict. Its ``type`` entry must have the form
        ``<module_name>.<ClassName>`` (exactly one dot); the module is
        imported as ``<this package>.<module_name>``. All remaining entries
        are passed to the class constructor as keyword arguments. The dict
        passed in is not modified.
    :return: the constructed instance.
    """
    # Copy first so that popping 'type' does not alter the caller's dict.
    init_args = cfg.copy()
    module_name, cls_name = init_args.pop('type').split('.')

    # Resolve the module relative to the package this file belongs to.
    target_cls = getattr(
        import_module(f'{__package__}.{module_name}'), cls_name)
    return target_cls(**init_args)