Python의 Hook 기능을 빠르게 익히세요

coldplay.xixi
풀어 주다: 2020-12-11 17:12:06
앞으로
8364명이 탐색했습니다.

Python tutorial 칼럼에서는 Python의 Hook Hook 기능을 소개합니다

Python의 Hook 기능을 빠르게 익히세요

많은 무료 학습 권장 사항이 있으니 python tutorial(동영상)

을 방문하세요. 1. 예 후크

훅 기능이라는 개념을 자주 듣습니다. 최근 타겟 감지 오픈 소스 프레임워크인 mmDetection을 살펴보았는데, 그 안에 Hook 프로그래밍 방식도 많이 들어있습니다. 그렇다면 Hook란 정확히 무엇일까요? 후크의 기능은 무엇입니까?

  • 훅이 뭐예요? 후크는 이름에서 알 수 있듯이 필요할 때 무언가를 걸어 두는 데 사용되는 후크로 이해될 수 있습니다. 구체적인 설명은 다음과 같습니다. 후크 기능은 우리가 구현한 후크 기능을 특정 순간에 대상 마운트 지점에 연결하는 것입니다.

  • 훅 기능의 역할 예를 들어, 후크 개념은 Windows 데스크톱 소프트웨어 개발에서 매우 일반적입니다. 특히 C++ MFC 프로그램과 같은 다양한 이벤트 트리거의 메커니즘에서는 시간을 모니터링해야 합니다. 마우스 왼쪽 버튼을 누르면 MFC는 onLeftKeyDown 후크 기능을 제공합니다. 분명히 MFC 프레임워크는 onLeftKeyDown의 특정 작업을 구현하지 않고 후크만 제공합니다. 이를 처리해야 할 경우 이 함수를 다시 작성하고 이 후크에 필요한 작업을 마운트하기만 하면 됩니다. 탑재하지 않으면 MFC 이벤트 트리거 메커니즘이 빈 작업을 수행합니다.

위에서 알 수 있듯이

  • hook 함수는 프로그램에 미리 정의된 함수입니다. 이 함수는 원래 프로그램 프로세스에 있습니다(hook 노출).

  • Hook을 정의해야 합니다. 기존 프로세스 함수 블록에서 특정 세부 사항을 구현하려면 후크 기능을 대상에서 사용할 수 있도록 후크에 구현을 연결하거나 등록해야 합니다. 후크는 프로그래밍 메커니즘이며 특정 언어와 직접적인 관련이 없습니다. 관계

  • 디자인 모드를 보면 훅 모드는 템플릿 방식의 확장

  • 훅은 등록할 때만 사용되기 때문에 원래 프로그램 과정에서는 등록이나 마운트가 없을 때 , 실행이 비어 있음(즉, 작업이 수행되지 않음)

  • 이 기사에서는 Python을 사용하여 Hooks 구현을 설명하고 오픈 소스 프로젝트에서 Hooks의 적용 사례를 보여줍니다. 후크 함수의 기능은 우리가 자주 듣는 또 다른 이름인 콜백 함수와 유사하며 동일한 모델에 따라 이해할 수 있습니다.

2. Hook 구현 예

Python의 Hook 기능을 빠르게 익히세요내가 아는 한 Hook 기능은 일종의 프로세스 처리에서 가장 일반적으로 사용됩니다. 이 프로세스에는 종종 여러 단계가 있습니다. 추가 작업을 추가하기 위한 유연성을 제공하기 위해 후크 기능이 이러한 단계에 마운트되는 경우가 많습니다.

다음은 간단한 예입니다. 이 예의 목적은 대기열에 콘텐츠를 삽입하는 보편적인 기능을 구현하는 것입니다. 2가지 프로세스 단계가 있습니다

input_filter_fn

  • 큐에 삽입하기 전에 데이터를 필터링해야 합니다. insert_queueinput_filter_fn

  • 插入队列 insert_queue

class ContentStash(object):
    """
    content stash for online operation
    pipeline is
    1. input_filter: filter some contents, no use to user
    2. insert_queue(redis or other broker): insert useful content to queue
    """

    def __init__(self):
        self.input_filter_fn = None
        self.broker = []

    def register_input_filter_hook(self, input_filter_fn):
        """
        register input filter function, parameter is content dict
        Args:
            input_filter_fn: input filter function

        Returns:

        """
        self.input_filter_fn = input_filter_fn

    def insert_queue(self, content):
        """
        insert content to queue
        Args:
            content: dict

        Returns:

        """
        self.broker.append(content)

    def input_pipeline(self, content, use=False):
        """
        pipeline of input for content stash
        Args:
            use: is use, defaul False
            content: dict

        Returns:

        """
        if not use:
            return

        # input filter
        if self.input_filter_fn:
            _filter = self.input_filter_fn(content)
            
        # insert to queue
        if not _filter:
            self.insert_queue(content)



# test
## 实现一个你所需要的钩子实现:比如如果content 包含time就过滤掉,否则插入队列
def input_filter_hook(content):
    """
    test input filter hook
    Args:
        content: dict

    Returns: None or content

    """
    if content.get('time') is None:
        return
    else:
        return content


# 原有程序
content = {'filename': 'test.jpg', 'b64_file': "#test", 'data': {"result": "cat", "probility": 0.9}}
content_stash = ContentStash('audit', work_dir='')

# 挂上钩子函数, 可以有各种不同钩子函数的实现,但是要主要函数输入输出必须保持原有程序中一致,比如这里是content
content_stash.register_input_filter_hook(input_filter_hook)

# 执行流程
content_stash.input_pipeline(content)
로그인 후 복사

3. hook在开源框架中的应用

3.1 keras

在深度学习训练流程中,hook函数体现的淋漓尽致。

一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:

  • 开始训练

  • 训练一个epoch前

  • 训练一个batch前

  • 训练一个batch后

  • 训练一个epoch后

  • 评估验证集

  • 结束训练

这些步骤是穿插在训练一个batch数据的过程中,这些可以理解成是钩子函数,我们可能需要在这些钩子函数中实现一些定制化的东西,比如在训练一个epoch后我们要保存下训练的模型,在结束训练时用最好的模型执行下测试集的效果等等。

keras中是通过各种回调函数来实现钩子hook功能的。这里放一个callback的父类,定制时只要继承这个父类,实现你过关注的钩子就可以了。

@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: Dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: Instance of `keras.models.Model`.
          Reference of the model being trained.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch (see method-specific docstrings).
  """

  def __init__(self):
    self.validation_data = None  # pylint: disable=g-missing-from-attributes
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None
    self._supports_tf_logs = False

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.

    Subclasses should override for any actions to run. This function should only
    be called during TRAIN mode.

    Arguments:
        epoch: Integer, index of epoch.
        logs: Dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.train_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch in `fit` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """
    # For backwards compatibility.
    self.on_batch_end(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `evaluate` methods.

    Also called at the beginning of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.test_step`. Typically,
          the values of the `Model`'s metrics are returned.  Example:
          `{'loss': 0.2, 'accuracy': 0.7}`.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `evaluate` methods.

    Also called at the end of a validation batch in the `fit`
    methods, if validation data is provided.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict, contains the return value of `model.predict_step`,
          it typically returns a dict with a key 'outputs' containing
          the model's outputs.
    """

  @doc_controls.for_subclass_implementers
  @generic_utils.default
  def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a batch in `predict` methods.

    Subclasses should override for any actions to run.

    Arguments:
        batch: Integer, index of batch within the current epoch.
        logs: Dict. Aggregated metric results up until this batch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_begin(self, logs=None):
    """Called at the beginning of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_train_end(self, logs=None):
    """Called at the end of training.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to `on_epoch_end()`
          is passed to this argument for this method but that may change in
          the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently the output of the last call to
          `on_test_batch_end()` is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  @doc_controls.for_subclass_implementers
  def on_predict_end(self, logs=None):
    """Called at the end of prediction.

    Subclasses should override for any actions to run.

    Arguments:
        logs: Dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def _implements_train_batch_hooks(self):
    """Determines if this Callback should be called for each train batch."""
    return (not generic_utils.is_default(self.on_batch_begin) or
            not generic_utils.is_default(self.on_batch_end) or
            not generic_utils.is_default(self.on_train_batch_begin) or
            not generic_utils.is_default(self.on_train_batch_end))
로그인 후 복사

这些钩子的原始程序是在模型训练流程中的

keras源码位置: tensorflowpythonkerasenginetraining.py

部分摘录如下(## I am hook):

# Container that configures and calls `tf.keras.Callback`s.
      if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            model=self,
            verbose=verbose,
            epochs=epochs,
            steps=data_handler.inferred_steps)

      ## I am hook
      callbacks.on_train_begin()
      training_logs = None
      # Handle fault-tolerance for multi-worker.
      # TODO(omalleyt): Fix the ordering issues that mean this has to
      # happen after `callbacks.on_train_begin`.
      data_handler._initial_epoch = (  # pylint: disable=protected-access
          self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
      for epoch, iterator in data_handler.enumerate_epochs():
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with data_handler.catch_stop_iteration():
          for step in data_handler.steps():
            with trace.Trace(
                'TraceContext',
                graph_type='train',
                epoch_num=epoch,
                step_num=step,
                batch_size=batch_size):
              ## I am hook
              callbacks.on_train_batch_begin(step)
              tmp_logs = train_function(iterator)
              if data_handler.should_sync:
                context.async_wait()
              logs = tmp_logs  # No error, now safe to assign to logs.
              end_step = step + data_handler.step_increment
              callbacks.on_train_batch_end(end_step, logs)
        epoch_logs = copy.copy(logs)

        # Run validation.

        ## I am hook
        callbacks.on_epoch_end(epoch, epoch_logs)
로그인 후 복사

3.2 mmdetection

mmdetection是一个目标检测的开源框架,集成了许多不同的目标检测深度学习算法(pytorch版),如faster-rcnn, fpn, retianet等。里面也大量使用了hook,暴露给应用实现流程中具体部分。

详见https://github.com/open-mmlab/mmdetection

🎜
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders

    # put model on gpus

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        eval_cfg = cfg.get('evaluation', {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
로그인 후 복사
로그인 후 복사
🎜3 .The Hook은 오픈 소스 프레임워크의 애플리케이션🎜

3.1 keras

🎜딥 러닝 훈련 과정에서 Hook 기능이 완전히 반영됩니다. 🎜🎜훈련 프로세스(데이터 준비 제외)는 훈련 세트를 여러 번 폴링하고, 각 시간을 에포크라고 하며, 각 에포크는 훈련을 위해 여러 배치로 나뉩니다. 프로세스는 다음과 같이 분류됩니다. 🎜🎜🎜🎜훈련 시작🎜🎜🎜🎜한 에포크 훈련 전🎜🎜🎜🎜배치 훈련 전🎜🎜🎜🎜배치 훈련 후🎜🎜🎜🎜한 에포크 훈련 후🎜🎜 🎜🎜평가 검증 세트🎜🎜🎜🎜훈련 종료🎜🎜🎜🎜이러한 단계는 일련의 데이터를 훈련하는 과정에 산재해 있습니다. 이는 후크 기능으로 이해될 수 있습니다. < code>특정 에포크 동안 학습한 후 학습된 모델을 저장해야 하며, 학습을 종료할 때 가장 적합한 모델을 사용하여 테스트 세트 효과 등을 수행합니다. 🎜🎜훅 기능은 다양한 콜백 기능을 통해 케라스에서 구현됩니다. 여기에 콜백의 상위 클래스를 넣습니다. 커스터마이징할 때 이 상위 클래스를 상속하고 관심 있는 후크만 구현하면 됩니다. 🎜rrreee🎜이 후크의 원래 프로그램은 모델 훈련 프로세스에 있습니다🎜🎜🎜keras 소스 코드 위치: tensorflowpythonkerasenginetraining.py🎜🎜🎜발췌 부분은 다음과 같습니다(## 나는 후크입니다):🎜rrreee

3.2 mmDetection

🎜mmDetection은 fast-rcnn, fpn, retianet 등과 같은 다양한 표적 탐지 딥 러닝 알고리즘(pytorch 버전)을 통합하는 표적 탐지를 위한 오픈 소스 프레임워크입니다. 후크는 또한 애플리케이션 구현 프로세스의 특정 부분을 노출하기 위해 광범위하게 사용됩니다. 🎜🎜자세한 내용은 https://github.com/open-mmlab/mmDetection을 참조하세요🎜

这里看一个训练的调用例子(摘录)(https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py

def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders

    # put model on gpus

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get(&#39;momentum_config&#39;, None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        eval_cfg = cfg.get(&#39;evaluation&#39;, {})
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get(&#39;custom_hooks&#39;, None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f&#39;custom_hooks expect list type, but got {type(custom_hooks)}&#39;
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), \
                &#39;Each item in custom_hooks expects dict type, but got &#39; \
                f&#39;{type(hook_cfg)}&#39;
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop(&#39;priority&#39;, &#39;NORMAL&#39;)
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)
로그인 후 복사
로그인 후 복사

4. 总结

本文介绍了hook的概念和应用,并给出了python的实现细则。希望对比有帮助。总结如下:

  • hook函数是流程中预定义好的一个步骤,没有实现

  • 挂载或者注册时, 流程执行就会执行这个钩子函数

  • 回调函数和hook函数功能上是一致的

  • hook设计方式带来灵活性,如果流程中有一个步骤,你想让调用方来实现,你可以用hook函数

相关免费学习推荐:php编程(视频)

위 내용은 Python의 Hook 기능을 빠르게 익히세요의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

관련 라벨:
원천:csdn.net
본 웹사이트의 성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
인기 튜토리얼
더>
최신 다운로드
더>
웹 효과
웹사이트 소스 코드
웹사이트 자료
프론트엔드 템플릿