Hydraで特定のクラスを読み込む

概要

ymym3412.hatenablog.com

qiita.com

上記の記事を参考にHydra+MLflowで機械学習を行っていた。
ここで、argparseの項目を全てHydraに変更したところ、グリッドサーチを行うときにやって欲しくない組み合わせに遭遇した。
そこで、特定の組み合わせはPythonのクラスのメンバー変数でまとめることにした。

開発環境

具体例

フォルダ構成

/
┣ train.py  # config.ymalを読み込むファイル
┣ conf
┃  ┗ config.yaml
┣ dataset_1
┃  ┣ trainファイル
┃  ┣ devファイル
┃  ┗ testファイル
┗ dataset_2
    ┣ trainファイル
    ┣ devファイル
    ┗ testファイル

config.yaml

# @package _global_ 
dataset:
      train: dataset_1/trainファイル
      dev: dataset_1/devファイル
      test: dataset_1/testファイル

実行

上記のようなフォルダ構成でtrainファイル、devファイル、testファイルを学習・評価用のデータとして用意して、読み込んで学習させる。

python train.py -m \
dataset.train=dataset_1/trainファイル,dataset_2/trainファイル \
dataset.dev=dataset_1/devファイル,dataset_2/devファイル \
dataset.test=dataset_1/testファイル,dataset_2/testファイル

と実行。
このままではグリッドサーチでdataset_1dataset_2のtrainファイル、devファイル、testファイルで8通りの学習をさせてしまう。

解決策

hydra.cc

公式を参考

フォルダ構成

/
┣ train.py  # config.ymalを読み込むファイル
┣ util.py  # dataset_1.yaml、dataset_2.yamlで読み込むファイル
┣ conf
┃  ┣ dataset  # 追加
┃  ┃  ┣ dataset_1.yaml
┃  ┃  ┗ dataset_2.yaml
┃  ┗ config.yaml
┣ dataset_1
┃  ┣ trainファイル
┃  ┣ devファイル
┃  ┗ testファイル
┗ dataset_2
    ┣ trainファイル
    ┣ devファイル
    ┗ testファイル

util.py

class Dataset(object):
     def __init__(self, train, dev, test):
           self.train = train
           self.dev = dev
           self.test = test
・・・

config.yaml

# @package _global_ 
defaults:
      - dataset: dataset_1

dataset

dataset_1.yaml

# @package _group_
_target_: util.Dataset
train: dataset_1/trainファイル
dev: dataset_1/devファイル
test: dataset_1/testファイル

dataset_2.yaml

# @package _group_
_target_: util.Dataset
train: dataset_2/trainファイル
dev: dataset_2/devファイル
test: dataset_2/testファイル

train.py

・・・
dataset = hydra.utils.instantiate(cfg.dataset)
・・・

実行

上記のようなフォルダ構成・コードに書き加えて

python train.py -m dataset=dataset_1,dataset_2

を実行