概要
上記の記事を参考にHydra+MLflowで機械学習を行っていた。
ここで、argparse
の項目を全てHydra
に変更したところ、グリッドサーチを行うときにやって欲しくない組み合わせに遭遇した。
そこで、特定の組み合わせはPythonのクラスのメンバー変数でまとめることにした。
開発環境
- python 3.6
- Hydra 1.0.3
具体例
フォルダ構成
/ ┣ train.py # config.ymalを読み込むファイル ┣ conf ┃ ┗ config.yaml ┣ dataset_1 ┃ ┣ trainファイル ┃ ┣ devファイル ┃ ┗ testファイル ┗ dataset_2 ┣ trainファイル ┣ devファイル ┗ testファイル
config.yaml
# @package _global_ dataset: train: dataset_1/trainファイル dev: dataset_1/devファイル test: dataset_1/testファイル
実行
上記のようなフォルダ構成でtrainファイル、devファイル、testファイルを学習・評価用のデータとして用意して、読み込んで学習させる。
python train.py -m \ dataset.train=dataset_1/trainファイル,dataset_2/trainファイル \ dataset.dev=dataset_1/devファイル,dataset_2/devファイル \ dataset.test=dataset_1/testファイル,dataset_2/testファイル
と実行。
このままではグリッドサーチでdataset_1
、dataset_2
のtrainファイル、devファイル、testファイルで8通りの学習をさせてしまう。
解決策
公式を参考
フォルダ構成
/ ┣ train.py # config.ymalを読み込むファイル ┣ util.py # dataset_1.yaml、dataset_2.yamlで読み込むファイル ┣ conf ┃ ┣ dataset # 追加 ┃ ┃ ┣ dataset_1.yaml ┃ ┃ ┗ dataset_2.yaml ┃ ┗ config.yaml ┣ dataset_1 ┃ ┣ trainファイル ┃ ┣ devファイル ┃ ┗ testファイル ┗ dataset_2 ┣ trainファイル ┣ devファイル ┗ testファイル
util.py
class Dataset(object): def __init__(self, train, dev, test): self.train = train self.dev = dev self.test = test ・・・
config.yaml
# @package _global_ defaults: - dataset: dataset_1
dataset
dataset_1.yaml
# @package _group_ _target_: util.Dataset train: dataset_1/trainファイル dev: dataset_1/devファイル test: dataset_1/testファイル
dataset_2.yaml
# @package _group_ _target_: util.Dataset train: dataset_2/trainファイル dev: dataset_2/devファイル test: dataset_2/testファイル
train.py
・・・ dataset = hydra.utils.instantiate(cfg.dataset) ・・・
実行
上記のようなフォルダ構成・コードに書き加えて
python train.py -m dataset=dataset_1,dataset_2
を実行