[Bug] MLflowVisBackend crashed on after_run hook #1574

Open
2 tasks done
maxvandenhoven opened this issue Sep 19, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@maxvandenhoven

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]'), ('CUDA available', True), ('MUSA available', False), ('numpy_random_seed', 2147483648), ('GPU 0', 'NVIDIA GeForce RTX 3070 Laptop GPU'), ('CUDA_HOME', '/usr/local/cuda'), ('NVCC', 'Cuda compilation tools, release 11.8, V11.8.89'), ('GCC', 'gcc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0'), ('PyTorch', '2.0.0'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.8\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37\n - CuDNN 8.7\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow 
-Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.15.0'), ('OpenCV', '4.8.1'), ('MMEngine', '0.10.3')])

Reproduces the problem - code sample

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from mmengine.model import BaseModel
from mmengine.evaluator import BaseMetric
from mmengine.registry import MODELS, DATASETS, METRICS


@MODELS.register_module()
class MyAwesomeModel(BaseModel):
    def __init__(self, layers=4, activation='relu') -> None:
        super().__init__()
        if activation == 'relu':
            act_type = nn.ReLU
        elif activation == 'silu':
            act_type = nn.SiLU
        elif activation == 'none':
            act_type = nn.Identity
        else:
            raise NotImplementedError
        sequence = [nn.Linear(2, 64), act_type()]
        for _ in range(layers-1):
            sequence.extend([nn.Linear(64, 64), act_type()])
        self.mlp = nn.Sequential(*sequence)
        self.classifier = nn.Linear(64, 2)

    def forward(self, data, labels, mode):
        x = self.mlp(data)
        x = self.classifier(x)
        if mode == 'tensor':
            return x
        elif mode == 'predict':
            return F.softmax(x, dim=1), labels
        elif mode == 'loss':
            return {
                'loss_train_loss1': F.cross_entropy(x, labels),
                'loss_train_loss2': F.cross_entropy(x, labels),
            }


@DATASETS.register_module()
class MyDataset(Dataset):
    def __init__(self, is_train, size):
        self.is_train = is_train
        if self.is_train:
            torch.manual_seed(0)
            self.labels = torch.randint(0, 2, (size,))
        else:
            torch.manual_seed(3407)
            self.labels = torch.randint(0, 2, (size,))
        r = 3 * (self.labels+1) + torch.randn(self.labels.shape)
        theta = torch.rand(self.labels.shape) * 2 * torch.pi
        self.data = torch.vstack([r*torch.cos(theta), r*torch.sin(theta)]).T

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)


@METRICS.register_module()
class Accuracy(BaseMetric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(r['correct'] for r in results)
        total_size = sum(r['batch_size'] for r in results)
        return dict(metric_train_accuracy=100*total_correct/total_size, metric_train_accuracy2=100*total_correct/total_size)
    

from torch.utils.data import DataLoader, default_collate
from torch.optim import Adam
from mmengine.runner import Runner


runner = Runner(
    # your model
    model=MyAwesomeModel(
        layers=2,
        activation='relu'),
    # work directory for saving checkpoints and logs
    work_dir='exp/my_awesome_model',

    # training data
    train_dataloader=DataLoader(
        dataset=MyDataset(
            is_train=True,
            size=10000),
        shuffle=True,
        collate_fn=default_collate,
        batch_size=64,
        pin_memory=True,
        num_workers=2),
    # training configurations
    train_cfg=dict(
        by_epoch=True,   # display in epoch number instead of iterations
        max_epochs=3,
        val_begin=1,     # start validation from the 1st epoch
        val_interval=1), # do validation every 1 epoch

    # OptimizerWrapper, a new concept in MMEngine for richer optimization options.
    # The default value works fine for most cases. You may check the documentation
    # for more details, e.g. 'AmpOptimWrapper' for enabling mixed precision
    # training.
    optim_wrapper=dict(
        optimizer=dict(
            type=Adam,
            lr=0.001)),
    # ParamScheduler to adjust learning rates or momentums during training
    param_scheduler=dict(
        type='MultiStepLR',
        by_epoch=True,
        milestones=[4, 8],
        gamma=0.1),

    # validation data
    val_dataloader=DataLoader(
        dataset=MyDataset(
            is_train=False,
            size=1000),
        shuffle=False,
        collate_fn=default_collate,
        batch_size=1000,
        pin_memory=True,
        num_workers=2),
    # validation configurations, usually leave it an empty dict
    val_cfg=dict(),
    # evaluation metrics and evaluator
    val_evaluator=dict(type=Accuracy),

    # the following are advanced configurations; keep the defaults unless needed
    # hooks are advanced usage; keep the defaults unless needed
    default_hooks=dict(
        # the most commonly used hook for modifying checkpoint saving interval
        checkpoint=dict(type='CheckpointHook', interval=1)),

    # `launcher` and `env_cfg` are responsible for the distributed environment
    launcher='none',
    env_cfg=dict(
        cudnn_benchmark=False,   # whether to enable cudnn_benchmark
        backend='nccl',   # distributed communication backend
        mp_cfg=dict(mp_start_method='fork')),  # multiprocessing configs
    log_level='INFO',

    # load model weights from given path. None for no loading.
    load_from=None,
    # resume training from the given path
    resume=False,


    visualizer=dict(
        type="Visualizer",
        vis_backends=[
            dict(
                type="MLflowVisBackend",
                tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
            )
        ]
    )
)

# start training your model
runner.train()

Reproduces the problem - command or script

python test.py

Reproduces the problem - error message

Traceback (most recent call last):
  File "/workdir/test.py", line 176, in <module>
    runner.train()
  File "/opt/conda/envs/depth3dlane/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1778, in train
    self.call_hook('after_run')
  File "/opt/conda/envs/depth3dlane/lib/python3.10/site-packages/mmengine/runner/runner.py", line 1839, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File "/opt/conda/envs/depth3dlane/lib/python3.10/site-packages/mmengine/hooks/logger_hook.py", line 325, in after_run
    runner.visualizer.close()
  File "/opt/conda/envs/depth3dlane/lib/python3.10/site-packages/mmengine/visualization/visualizer.py", line 1150, in close
    vis_backend.close()
  File "/opt/conda/envs/depth3dlane/lib/python3.10/site-packages/mmengine/visualization/vis_backend.py", line 820, in close
    for filename in scandir(self.cfg.work_dir, self._artifact_suffix,
AttributeError: 'MLflowVisBackend' object has no attribute 'cfg'

Additional information

I am trying out MMEngine with MLflow as a logging backend. I have an MLflow server running on port 5000, and its address is stored in the MLFLOW_TRACKING_URI environment variable. Training works as expected and I can see logs showing up, but at the end of the run closing the logger fails with the AttributeError above. It seems that the runner does not automatically call the add_config method, which is supposed to set the cfg attribute on the MLflowVisBackend instance. As this is my first time using MMEngine, I am wondering how to fix this; I suspect I need to call add_config somewhere myself. I would also like to know how to use the other artifact logging methods with the runner, as I am not sure where to call them in my model.
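To make the suspected failure mode concrete, here is a minimal sketch that does not import MMEngine at all. `ToyBackend`, `GuardedToyBackend` and `try_close` are hypothetical stand-ins written for illustration, not MMEngine classes. They only mimic the pattern visible in the traceback: `close()` reads `self.cfg`, and `self.cfg` exists only after `add_config()` has run. The guarded variant shows one possible way a backend could tolerate a missing config; it is not a claim about how MMEngine should be patched.

```python
from types import SimpleNamespace


class ToyBackend:
    """Hypothetical stand-in: `cfg` is set only by add_config()."""

    def add_config(self, config):
        # The attribute does not exist until this method has been called.
        self.cfg = config

    def close(self):
        # Same pattern as the failing line in the traceback:
        # self.cfg is read unconditionally.
        return f"scanning {self.cfg.work_dir}"


class GuardedToyBackend(ToyBackend):
    """Same stand-in, but close() tolerates a missing config."""

    def close(self):
        cfg = getattr(self, "cfg", None)
        if cfg is None:
            return "no config recorded; skipping artifact scan"
        return super().close()


def try_close(backend):
    """Return close()'s result, or the error text if it raises AttributeError."""
    try:
        return backend.close()
    except AttributeError as exc:
        return f"AttributeError: {exc}"


# Closing without ever calling add_config() fails like the real traceback.
never_configured = try_close(ToyBackend())

# Calling add_config() first makes close() work.
configured_backend = ToyBackend()
configured_backend.add_config(SimpleNamespace(work_dir="exp/my_awesome_model"))
configured = try_close(configured_backend)

# The guarded variant skips the scan instead of crashing.
guarded = try_close(GuardedToyBackend())
```

In this toy model the crash depends only on call order: any code path that reaches `close()` without a prior `add_config()` will fail, no matter how well training itself went.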

@maxvandenhoven maxvandenhoven added the bug Something isn't working label Sep 19, 2024
@maxvandenhoven

Update: after looking at the runner source code, it appears that the add_config method on the visualizer is only called when the runner has a cfg attribute set. Therefore, the MLflowVisBackend only seems to work when the runner is constructed from a config file, which I managed to verify using the config from the runner tutorial. I'm not sure whether this is the intended behavior for the MLflowVisBackend, but I would be willing to contribute to making it more reliable when constructing the runner from its arguments. However, this would be my first open source contribution, so some guidance would be really helpful!
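The behaviour described in this update can be sketched as follows. `ToyRunner` and `ToyVisualizer` are hypothetical stand-ins that model only the conditional call, under the assumption that a runner built from plain arguments ends up with an empty, falsy config. They are not copies of the MMEngine source.

```python
class ToyVisualizer:
    """Hypothetical stand-in that records whether add_config() was called."""

    def __init__(self):
        self.config = None

    def add_config(self, config):
        self.config = config


class ToyRunner:
    """Hypothetical stand-in for the runner behaviour described above."""

    def __init__(self, work_dir, cfg=None):
        self.work_dir = work_dir
        # Without a config, fall back to an empty (falsy) mapping.
        self.cfg = dict(cfg) if cfg is not None else {}
        self.visualizer = ToyVisualizer()
        # The config is forwarded to the visualizer only when it is non-empty.
        if self.cfg:
            self.visualizer.add_config(self.cfg)


# Built from arguments only: the visualizer never receives a config.
from_args = ToyRunner(work_dir="exp/my_awesome_model")

# Built with a config: the visualizer receives it.
from_cfg = ToyRunner(
    work_dir="exp/my_awesome_model",
    cfg={"work_dir": "exp/my_awesome_model"},
)
```

If the real `Runner` follows this pattern, passing a small `cfg` dict containing at least `work_dir` alongside the other constructor arguments might be enough for the backend to receive its config. That is an untested assumption and would need to be checked against the installed MMEngine version.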
