diff --git a/dtrb/.gitignore b/dtrb/.gitignore
deleted file mode 100644
index 8b5757e..0000000
--- a/dtrb/.gitignore
+++ /dev/null
@@ -1,211 +0,0 @@
-**/saved_models/*
-**/data_lmdb_release/*
-**/image_release/*
-**/result/*
-*.mdb
-*.pth
-*.tar
-*.sh
-*.txt
-*.ipynb
-*.zip
-*.eps
-*.pdf
-dtrb/
-runs/
-dataset/
-*.mp4
-### Linux ###
-*~
-
-# temporary files which can be created if a process still has a handle open of a deleted file
-.fuse_hidden*
-
-# KDE directory preferences
-.directory
-
-# Linux trash folder which might appear on any partition or disk
-.Trash-*
-
-# .nfs files are created when an open file is removed but is still being accessed
-.nfs*
-
-### OSX ###
-# General
-.DS_Store
-.AppleDouble
-.LSOverride
-
-# Icon must end with two \r
-Icon
-
-# Thumbnails
-._*
-
-# Files that might appear in the root of a volume
-.DocumentRevisions-V100
-.fseventsd
-.Spotlight-V100
-.TemporaryItems
-.Trashes
-.VolumeIcon.icns
-.com.apple.timemachine.donotpresent
-
-# Directories potentially created on remote AFP share
-.AppleDB
-.AppleDesktop
-Network Trash Folder
-Temporary Items
-.apdisk
-
-### Python ###
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# celery beat schedule file
-celerybeat-schedule
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-### Python Patch ###
-.venv/
-
-### Python.VirtualEnv Stack ###
-# Virtualenv
-# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
-[Bb]in
-[Ii]nclude
-[Ll]ib
-[Ll]ib64
-[Ll]ocal
-[Ss]cripts
-pyvenv.cfg
-pip-selfcheck.json
-
-### Windows ###
-# Windows thumbnail cache files
-Thumbs.db
-ehthumbs.db
-ehthumbs_vista.db
-
-# Dump file
-*.stackdump
-
-# Folder config file
-[Dd]esktop.ini
-
-# Recycle Bin used on file shares
-$RECYCLE.BIN/
-
-# Windows Installer files
-*.cab
-*.msi
-*.msix
-*.msm
-*.msp
-
-# Windows shortcuts
-*.lnk
diff --git a/dtrb/LICENSE.md b/dtrb/LICENSE.md
deleted file mode 100644
index d645695..0000000
--- a/dtrb/LICENSE.md
+++ /dev/null
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/dtrb/OCRmodel.py b/dtrb/OCRmodel.py
deleted file mode 100644
index 220e732..0000000
--- a/dtrb/OCRmodel.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""
-Copyright (c) 2019-present NAVER Corp.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-import torch.nn as nn
-
-from modules.transformation import TPS_SpatialTransformerNetwork
-from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor
-from modules.sequence_modeling import BidirectionalLSTM
-from modules.prediction import Attention
-
-
-class Model(nn.Module):
-
- def __init__(self, opt):
- super(Model, self).__init__()
- self.opt = opt
- self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
- 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
-
-
- self.Transformation = TPS_SpatialTransformerNetwork(
- F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
-
- """ FeatureExtraction """
- self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
-
- self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
- self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
-
- """ Sequence modeling"""
-
- self.SequenceModeling = nn.Sequential(
- BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
- BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
- self.SequenceModeling_output = opt.hidden_size
-
- """ Prediction """
- self.Prediction = nn.Linear(self.SequenceModeling_output, 51)
-
- def forward(self, input, text, is_train=True):
- """ Transformation stage """
-
- input = self.Transformation(input)
-
- """ Feature extraction stage """
- visual_feature = self.FeatureExtraction(input)
- visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
- visual_feature = visual_feature.squeeze(3)
-
- """ Sequence modeling stage """
- contextual_feature = self.SequenceModeling(visual_feature)
- """ Prediction stage """
-
- prediction = self.Prediction(contextual_feature.contiguous())
-
- return prediction
\ No newline at end of file
diff --git a/dtrb/README.md b/dtrb/README.md
deleted file mode 100644
index 50de47c..0000000
--- a/dtrb/README.md
+++ /dev/null
@@ -1,192 +0,0 @@
-# What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis
-| [paper](https://arxiv.org/abs/1904.01906) | [training and evaluation data](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) | [failure cases and cleansed label](https://github.com/clovaai/deep-text-recognition-benchmark#download-failure-cases-and-cleansed-label-from-here) | [pretrained model](https://www.dropbox.com/sh/j3xmli4di1zuv3s/AAArdcPgz7UFxIHUuKNOeKv_a?dl=0) | [Baidu ver(passwd:rryk)](https://pan.baidu.com/s/1KSNLv4EY3zFWHpBYlpFCBQ) |
-
-Official PyTorch implementation of our four-stage STR framework, that most existing STR models fit into.
-Using this framework allows for the module-wise contributions to performance in terms of accuracy, speed, and memory demand, under one consistent set of training and evaluation datasets.
-Such analyses clean up the hindrance on the current comparisons to understand the performance gain of the existing modules.
-
-
-## Honors
-Based on this framework, we recorded the 1st place of [ICDAR2013 focused scene text](https://rrc.cvc.uab.es/?ch=2&com=evaluation&task=3), [ICDAR2019 ArT](https://rrc.cvc.uab.es/files/ICDAR2019-ArT.pdf) and 3rd place of [ICDAR2017 COCO-Text](https://rrc.cvc.uab.es/?ch=5&com=evaluation&task=2), [ICDAR2019 ReCTS (task1)](https://rrc.cvc.uab.es/files/ICDAR2019-ReCTS.pdf).
-The difference between our paper and ICDAR challenge is summarized [here](https://github.com/clovaai/deep-text-recognition-benchmark/issues/13).
-
-## Updates
-**Aug 3, 2020**: added [guideline to use Baidu warpctc](https://github.com/clovaai/deep-text-recognition-benchmark/pull/209) which reproduces CTC results of our paper.
-**Dec 27, 2019**: added [FLOPS](https://github.com/clovaai/deep-text-recognition-benchmark/issues/125) in our paper, and minor updates such as log_dataset.txt and [ICDAR2019-NormalizedED](https://github.com/clovaai/deep-text-recognition-benchmark/blob/86451088248e0490ff8b5f74d33f7d014f6c249a/test.py#L139-L165).
-**Oct 22, 2019**: added [confidence score](https://github.com/clovaai/deep-text-recognition-benchmark/issues/82), and arranged the output form of training logs.
-**Jul 31, 2019**: The paper is accepted at International Conference on Computer Vision (ICCV), Seoul 2019, as an oral talk.
-**Jul 25, 2019**: The code for floating-point 16 calculation, check [@YacobBY's](https://github.com/YacobBY) [pull request](https://github.com/clovaai/deep-text-recognition-benchmark/pull/36)
-**Jul 16, 2019**: added [ST_spe.zip](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt) dataset, word images contain special characters in SynthText (ST) dataset, see [this issue](https://github.com/clovaai/deep-text-recognition-benchmark/issues/7#issuecomment-511727025)
-**Jun 24, 2019**: added gt.txt of failure cases that contains path and label of each image, see [image_release_190624.zip](https://drive.google.com/open?id=1VAP9l5GL5fgptgKDLio_h3nMe7X9W0Mf)
-**May 17, 2019**: uploaded resources in Baidu Netdisk also, added [Run demo](https://github.com/clovaai/deep-text-recognition-benchmark#run-demo-with-pretrained-model). (check [@sharavsambuu's](https://github.com/sharavsambuu) [colab demo also](https://colab.research.google.com/drive/1PHnc_QYyf9b1_KJ1r15wYXaOXkdm1Mrk))
-**May 9, 2019**: PyTorch version updated from 1.0.1 to 1.1.0, use torch.nn.CTCLoss instead of torch-baidu-ctc, and various minor updated.
-
-## Getting Started
-### Dependency
-- This work was tested with PyTorch 1.3.1, CUDA 10.1, python 3.6 and Ubuntu 16.04.
You may need `pip3 install torch==1.3.1`.
-In the paper, expriments were performed with **PyTorch 0.4.1, CUDA 9.0**.
-- requirements : lmdb, pillow, torchvision, nltk, natsort
-```
-pip3 install lmdb pillow torchvision nltk natsort
-```
-
-### Download lmdb dataset for traininig and evaluation from [here](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0)
-data_lmdb_release.zip contains below.
-training datasets : [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/)[1] and [SynthText (ST)](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)[2] \
-validation datasets : the union of the training sets [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], and [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6].\
-evaluation datasets : benchmark evaluation datasets, consist of [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6], [IC03](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions)[7], [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf)[8], and [CUTE](http://cs-chan.com/downloads_CUTE80_dataset.html)[9].
-
-### Run demo with pretrained model
-1. Download pretrained model from [here](https://drive.google.com/drive/folders/15WPsuPJDCzhp2SvYZLRj8mAlT3zmoAMW)
-2. Add image files to test into `demo_image/`
-3. Run demo.py (add `--sensitive` option if you use case-sensitive model)
-```
-CUDA_VISIBLE_DEVICES=0 python3 demo.py \
---Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \
---image_folder demo_image/ \
---saved_model TPS-ResNet-BiLSTM-Attn.pth
-```
-
-#### prediction results
-
-| demo images | [TRBA (**T**PS-**R**esNet-**B**iLSTM-**A**ttn)](https://drive.google.com/open?id=1b59rXuGGmKne1AuHnkgDzoYgKeETNMv9) | [TRBA (case-sensitive version)](https://drive.google.com/open?id=1ajONZOgiG9pEYsQ-eBmgkVbMDuHgPCaY) |
-| --- | --- | --- |
-| | available | Available |
-| | shakeshack | SHARESHACK |
-| | london | Londen |
-| | greenstead | Greenstead |
-| | toast | TOAST |
-| | merry | MERRY |
-| | underground | underground |
-| | ronaldo | RONALDO |
-| | bally | BALLY |
-| | university | UNIVERSITY |
-
-
-### Training and evaluation
-1. Train CRNN[10] model
-```
-CUDA_VISIBLE_DEVICES=0 python3 train.py \
---train_data data_lmdb_release/training --valid_data data_lmdb_release/validation \
---select_data MJ-ST --batch_ratio 0.5-0.5 \
---Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC
-```
-2. Test CRNN[10] model. If you want to evaluate IC15-2077, check [data filtering part](https://github.com/clovaai/deep-text-recognition-benchmark/blob/c27abe6b4c681e2ee0784ad966602c056a0dd3b5/dataset.py#L148).
-```
-CUDA_VISIBLE_DEVICES=0 python3 test.py \
---eval_data data_lmdb_release/evaluation --benchmark_all_eval \
---Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC \
---saved_model saved_models/None-VGG-BiLSTM-CTC-Seed1111/best_accuracy.pth
-```
-
-3. Try to train and test our best accuracy model TRBA (**T**PS-**R**esNet-**B**iLSTM-**A**ttn) also. ([download pretrained model](https://drive.google.com/drive/folders/15WPsuPJDCzhp2SvYZLRj8mAlT3zmoAMW))
-```
-CUDA_VISIBLE_DEVICES=0 python3 train.py \
---train_data data_lmdb_release/training --valid_data data_lmdb_release/validation \
---select_data MJ-ST --batch_ratio 0.5-0.5 \
---Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn
-```
-```
-CUDA_VISIBLE_DEVICES=0 python3 test.py \
---eval_data data_lmdb_release/evaluation --benchmark_all_eval \
---Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \
---saved_model saved_models/TPS-ResNet-BiLSTM-Attn-Seed1111/best_accuracy.pth
-```
-
-### Arguments
-* `--train_data`: folder path to training lmdb dataset.
-* `--valid_data`: folder path to validation lmdb dataset.
-* `--eval_data`: folder path to evaluation (with test.py) lmdb dataset.
-* `--select_data`: select training data. default is MJ-ST, which means MJ and ST used as training data.
-* `--batch_ratio`: assign ratio for each selected data in the batch. default is 0.5-0.5, which means 50% of the batch is filled with MJ and the other 50% of the batch is filled ST.
-* `--data_filtering_off`: skip [data filtering](https://github.com/clovaai/deep-text-recognition-benchmark/blob/f2c54ae2a4cc787a0f5859e9fdd0e399812c76a3/dataset.py#L126-L146) when creating LmdbDataset.
-* `--Transformation`: select Transformation module [None | TPS].
-* `--FeatureExtraction`: select FeatureExtraction module [VGG | RCNN | ResNet].
-* `--SequenceModeling`: select SequenceModeling module [None | BiLSTM].
-* `--Prediction`: select Prediction module [CTC | Attn].
-* `--saved_model`: assign saved model to evaluation.
-* `--benchmark_all_eval`: evaluate with 10 evaluation dataset versions, same with Table 1 in our paper.
-
-## Download failure cases and cleansed label from [here](https://www.dropbox.com/s/5knh1gb1z593fxj/image_release_190624.zip?dl=0)
-image_release.zip contains failure case images and benchmark evaluation images with cleansed label.
-
-
-## When you need to train on your own dataset or Non-Latin language datasets.
-1. Create your own lmdb dataset.
-```
-pip3 install fire
-python3 create_lmdb_dataset.py --inputPath data/ --gtFile data/gt.txt --outputPath result/
-```
-The structure of data folder as below.
-```
-data
-├── gt.txt
-└── test
- ├── word_1.png
- ├── word_2.png
- ├── word_3.png
- └── ...
-```
-At this time, `gt.txt` should be `{imagepath}\t{label}\n`
-For example
-```
-test/word_1.png Tiredness
-test/word_2.png kills
-test/word_3.png A
-...
-```
-2. Modify `--select_data`, `--batch_ratio`, and `opt.character`, see [this issue](https://github.com/clovaai/deep-text-recognition-benchmark/issues/85).
-
-
-## Acknowledgements
-This implementation has been based on these repository [crnn.pytorch](https://github.com/meijieru/crnn.pytorch), [ocr_attention](https://github.com/marvis/ocr_attention).
-
-## Reference
-[1] M. Jaderberg, K. Simonyan, A. Vedaldi, and A. Zisserman. Synthetic data and artificial neural networks for natural scenetext recognition. In Workshop on Deep Learning, NIPS, 2014.
-[2] A. Gupta, A. Vedaldi, and A. Zisserman. Synthetic data fortext localisation in natural images. In CVPR, 2016.
-[3] D. Karatzas, F. Shafait, S. Uchida, M. Iwamura, L. G. i Big-orda, S. R. Mestre, J. Mas, D. F. Mota, J. A. Almazan, andL. P. De Las Heras. ICDAR 2013 robust reading competition. In ICDAR, pages 1484–1493, 2013.
-[4] D. Karatzas, L. Gomez-Bigorda, A. Nicolaou, S. Ghosh, A. Bagdanov, M. Iwamura, J. Matas, L. Neumann, V. R.Chandrasekhar, S. Lu, et al. ICDAR 2015 competition on ro-bust reading. In ICDAR, pages 1156–1160, 2015.
-[5] A. Mishra, K. Alahari, and C. Jawahar. Scene text recognition using higher order language priors. In BMVC, 2012.
-[6] K. Wang, B. Babenko, and S. Belongie. End-to-end scenetext recognition. In ICCV, pages 1457–1464, 2011.
-[7] S. M. Lucas, A. Panaretos, L. Sosa, A. Tang, S. Wong, andR. Young. ICDAR 2003 robust reading competitions. In ICDAR, pages 682–687, 2003.
-[8] T. Q. Phan, P. Shivakumara, S. Tian, and C. L. Tan. Recognizing text with perspective distortion in natural scenes. In ICCV, pages 569–576, 2013.
-[9] A. Risnumawan, P. Shivakumara, C. S. Chan, and C. L. Tan. A robust arbitrary text detection system for natural scene images. In ESWA, volume 41, pages 8027–8048, 2014.
-[10] B. Shi, X. Bai, and C. Yao. An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition. In TPAMI, volume 39, pages2298–2304. 2017.
-
-## Links
-- WebDemo : https://demo.ocr.clova.ai/
-Combination of Clova AI detection and recognition, additional/advanced features used for KOR/JPN.
-- Repo of detection : https://github.com/clovaai/CRAFT-pytorch
-
-## Citation
-Please consider citing this work in your publications if it helps your research.
-```
-@inproceedings{baek2019STRcomparisons,
- title={What Is Wrong With Scene Text Recognition Model Comparisons? Dataset and Model Analysis},
- author={Baek, Jeonghun and Kim, Geewook and Lee, Junyeop and Park, Sungrae and Han, Dongyoon and Yun, Sangdoo and Oh, Seong Joon and Lee, Hwalsuk},
- booktitle = {International Conference on Computer Vision (ICCV)},
- year={2019},
- pubstate={published},
- tppubtype={inproceedings}
-}
-```
-
-## Contact
-Feel free to contact us if there is any question:
-for code/paper Jeonghun Baek ku21fang@gmail.com; for collaboration hwalsuk.lee@navercorp.com (our team leader).
-
-## License
-Copyright (c) 2019-present NAVER Corp.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-
diff --git a/dtrb/create_lmdb_dataset.py b/dtrb/create_lmdb_dataset.py
deleted file mode 100644
index a58d137..0000000
--- a/dtrb/create_lmdb_dataset.py
+++ /dev/null
@@ -1,87 +0,0 @@
-""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
-
-import fire
-import os
-import lmdb
-import cv2
-
-import numpy as np
-
-
-def checkImageIsValid(imageBin):
- if imageBin is None:
- return False
- imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
- img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
- imgH, imgW = img.shape[0], img.shape[1]
- if imgH * imgW == 0:
- return False
- return True
-
-
-def writeCache(env, cache):
- with env.begin(write=True) as txn:
- for k, v in cache.items():
- txn.put(k, v)
-
-
-def createDataset(inputPath, gtFile, outputPath, checkValid=True):
- """
- Create LMDB dataset for training and evaluation.
- ARGS:
- inputPath : input folder path where starts imagePath
- outputPath : LMDB output path
- gtFile : list of image path and label
- checkValid : if true, check the validity of every image
- """
- os.makedirs(outputPath, exist_ok=True)
- env = lmdb.open(outputPath, map_size=1099511627776)
- cache = {}
- cnt = 1
-
- with open(gtFile, 'r', encoding='utf-8') as data:
- datalist = data.readlines()
-
- nSamples = len(datalist)
- for i in range(nSamples):
- imagePath, label = datalist[i].strip('\n').split('\t')
- imagePath = os.path.join(inputPath, imagePath)
-
- # # only use alphanumeric data
- # if re.search('[^a-zA-Z0-9]', label):
- # continue
-
- if not os.path.exists(imagePath):
- print('%s does not exist' % imagePath)
- continue
- with open(imagePath, 'rb') as f:
- imageBin = f.read()
- if checkValid:
- try:
- if not checkImageIsValid(imageBin):
- print('%s is not a valid image' % imagePath)
- continue
- except:
- print('error occured', i)
- with open(outputPath + '/error_image_log.txt', 'a') as log:
- log.write('%s-th image data occured error\n' % str(i))
- continue
-
- imageKey = 'image-%09d'.encode() % cnt
- labelKey = 'label-%09d'.encode() % cnt
- cache[imageKey] = imageBin
- cache[labelKey] = label.encode()
-
- if cnt % 1000 == 0:
- writeCache(env, cache)
- cache = {}
- print('Written %d / %d' % (cnt, nSamples))
- cnt += 1
- nSamples = cnt-1
- cache['num-samples'.encode()] = str(nSamples).encode()
- writeCache(env, cache)
- print('Created dataset with %d samples' % nSamples)
-
-
-if __name__ == '__main__':
- fire.Fire(createDataset)
diff --git a/dtrb/dataset.py b/dtrb/dataset.py
deleted file mode 100644
index d98224a..0000000
--- a/dtrb/dataset.py
+++ /dev/null
@@ -1,339 +0,0 @@
-import os
-import sys
-import re
-import six
-import math
-import lmdb
-import torch
-
-from natsort import natsorted
-from PIL import Image
-import numpy as np
-from torch.utils.data import Dataset, ConcatDataset, Subset
-from torch._utils import _accumulate
-import torchvision.transforms as transforms
-
-
-class Batch_Balanced_Dataset(object):
-
- def __init__(self, opt):
- """
- Modulate the data ratio in the batch.
- For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
- the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
- """
- log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
- dashed_line = '-' * 80
- print(dashed_line)
- log.write(dashed_line + '\n')
- print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
- log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
- assert len(opt.select_data) == len(opt.batch_ratio)
-
- _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
- self.data_loader_list = []
- self.dataloader_iter_list = []
- batch_size_list = []
- Total_batch_size = 0
- for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
- _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
- print(dashed_line)
- log.write(dashed_line + '\n')
- _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
- total_number_dataset = len(_dataset)
- log.write(_dataset_log)
-
- """
- The total number of data can be modified with opt.total_data_usage_ratio.
- ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
- See 4.2 section in our paper.
- """
- number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))
- dataset_split = [number_dataset, total_number_dataset - number_dataset]
- indices = range(total_number_dataset)
- _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
- for offset, length in zip(_accumulate(dataset_split), dataset_split)]
- selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
- selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
- print(selected_d_log)
- log.write(selected_d_log + '\n')
- batch_size_list.append(str(_batch_size))
- Total_batch_size += _batch_size
-
- _data_loader = torch.utils.data.DataLoader(
- _dataset, batch_size=_batch_size,
- shuffle=True,
- num_workers=int(opt.workers),
- collate_fn=_AlignCollate, pin_memory=True)
- self.data_loader_list.append(_data_loader)
- self.dataloader_iter_list.append(iter(_data_loader))
-
- Total_batch_size_log = f'{dashed_line}\n'
- batch_size_sum = '+'.join(batch_size_list)
- Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
- Total_batch_size_log += f'{dashed_line}'
- opt.batch_size = Total_batch_size
-
- print(Total_batch_size_log)
- log.write(Total_batch_size_log + '\n')
- log.close()
-
- def get_batch(self):
- balanced_batch_images = []
- balanced_batch_texts = []
-
- for i, data_loader_iter in enumerate(self.dataloader_iter_list):
- try:
- image, text = data_loader_iter.next()
- balanced_batch_images.append(image)
- balanced_batch_texts += text
- except StopIteration:
- self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
- image, text = self.dataloader_iter_list[i].next()
- balanced_batch_images.append(image)
- balanced_batch_texts += text
- except ValueError:
- pass
-
- balanced_batch_images = torch.cat(balanced_batch_images, 0)
-
- return balanced_batch_images, balanced_batch_texts
-
-
-def hierarchical_dataset(root, opt, select_data='/'):
- """ select_data='/' contains all sub-directory of root directory """
- dataset_list = []
- dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
- print(dataset_log)
- dataset_log += '\n'
- for dirpath, dirnames, filenames in os.walk(root+'/'):
- if not dirnames:
- select_flag = False
- for selected_d in select_data:
- if selected_d in dirpath:
- select_flag = True
- break
-
- if select_flag:
- dataset = LmdbDataset(dirpath, opt)
- sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
- print(sub_dataset_log)
- dataset_log += f'{sub_dataset_log}\n'
- dataset_list.append(dataset)
-
- concatenated_dataset = ConcatDataset(dataset_list)
-
- return concatenated_dataset, dataset_log
-
-
-class LmdbDataset(Dataset):
-
- def __init__(self, root, opt):
-
- self.root = root
- self.opt = opt
- self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
- if not self.env:
- print('cannot create lmdb from %s' % (root))
- sys.exit(0)
-
- with self.env.begin(write=False) as txn:
- nSamples = int(txn.get('num-samples'.encode()))
- self.nSamples = nSamples
-
- if self.opt.data_filtering_off:
- # for fast check or benchmark evaluation with no filtering
- self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
- else:
- """ Filtering part
- If you want to evaluate IC15-2077 & CUTE datasets which have special character labels,
- use --data_filtering_off and only evaluate on alphabets and digits.
- see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192
-
- And if you want to evaluate them with the model trained with --sensitive option,
- use --sensitive and --data_filtering_off,
- see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144
- """
- self.filtered_index_list = []
- for index in range(self.nSamples):
- index += 1 # lmdb starts with 1
- label_key = 'label-%09d'.encode() % index
- label = txn.get(label_key).decode('utf-8')
-
- if len(label) > self.opt.batch_max_length:
- # print(f'The length of the label is longer than max_length: length
- # {len(label)}, {label} in dataset {self.root}')
- continue
-
- # By default, images containing characters which are not in opt.character are filtered.
- # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
- out_of_char = f'[^{self.opt.character}]'
- if re.search(out_of_char, label.lower()):
- continue
-
- self.filtered_index_list.append(index)
-
- self.nSamples = len(self.filtered_index_list)
-
- def __len__(self):
- return self.nSamples
-
- def __getitem__(self, index):
- assert index <= len(self), 'index range error'
- index = self.filtered_index_list[index]
-
- with self.env.begin(write=False) as txn:
- label_key = 'label-%09d'.encode() % index
- label = txn.get(label_key).decode('utf-8')
- img_key = 'image-%09d'.encode() % index
- imgbuf = txn.get(img_key)
-
- buf = six.BytesIO()
- buf.write(imgbuf)
- buf.seek(0)
- try:
- if self.opt.rgb:
- img = Image.open(buf).convert('RGB') # for color image
- else:
- img = Image.open(buf).convert('L')
-
- except IOError:
- print(f'Corrupted image for {index}')
- # make dummy image and dummy label for corrupted image.
- if self.opt.rgb:
- img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
- else:
- img = Image.new('L', (self.opt.imgW, self.opt.imgH))
- label = '[dummy_label]'
-
- if not self.opt.sensitive:
- label = label.lower()
-
- # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
- out_of_char = f'[^{self.opt.character}]'
- label = re.sub(out_of_char, '', label)
-
- return (img, label)
-
-
-class RawDataset(Dataset):
-
- def __init__(self, root, opt):
- self.opt = opt
- self.image_path_list = []
- for dirpath, dirnames, filenames in os.walk(root):
- for name in filenames:
- _, ext = os.path.splitext(name)
- ext = ext.lower()
- if ext == '.jpg' or ext == '.jpeg' or ext == '.png':
- self.image_path_list.append(os.path.join(dirpath, name))
-
- self.image_path_list = natsorted(self.image_path_list)
- self.nSamples = len(self.image_path_list)
-
- def __len__(self):
- return self.nSamples
-
- def __getitem__(self, index):
-
- try:
- if self.opt.rgb:
- img = Image.open(self.image_path_list[index]).convert('RGB') # for color image
- else:
- img = Image.open(self.image_path_list[index]).convert('L')
-
- except IOError:
- print(f'Corrupted image for {index}')
- # make dummy image and dummy label for corrupted image.
- if self.opt.rgb:
- img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
- else:
- img = Image.new('L', (self.opt.imgW, self.opt.imgH))
-
- return (img, self.image_path_list[index])
-
-
-class ResizeNormalize(object):
-
- def __init__(self, size, interpolation=Image.BICUBIC):
- self.size = size
- self.interpolation = interpolation
- self.toTensor = transforms.ToTensor()
-
- def __call__(self, img):
- img = img.resize(self.size, self.interpolation)
- img = self.toTensor(img)
- img.sub_(0.5).div_(0.5)
- return img
-
-
-class NormalizePAD(object):
-
- def __init__(self, max_size, PAD_type='right'):
- self.toTensor = transforms.ToTensor()
- self.max_size = max_size
- self.max_width_half = math.floor(max_size[2] / 2)
- self.PAD_type = PAD_type
-
- def __call__(self, img):
- img = self.toTensor(img)
- img.sub_(0.5).div_(0.5)
- c, h, w = img.size()
- Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
- Pad_img[:, :, :w] = img # right pad
- if self.max_size[2] != w: # add border Pad
- Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
-
- return Pad_img
-
-
-class AlignCollate(object):
-
- def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
- self.imgH = imgH
- self.imgW = imgW
- self.keep_ratio_with_pad = keep_ratio_with_pad
-
- def __call__(self, batch):
- batch = filter(lambda x: x is not None, batch)
- images, labels = zip(*batch)
-
- if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper
- resized_max_w = self.imgW
- input_channel = 3 if images[0].mode == 'RGB' else 1
- transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
-
- resized_images = []
- for image in images:
- w, h = image.size
- ratio = w / float(h)
- if math.ceil(self.imgH * ratio) > self.imgW:
- resized_w = self.imgW
- else:
- resized_w = math.ceil(self.imgH * ratio)
-
- resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
- resized_images.append(transform(resized_image))
- # resized_image.save('./image_test/%d_test.jpg' % w)
-
- image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
-
- else:
- transform = ResizeNormalize((self.imgW, self.imgH))
- image_tensors = [transform(image) for image in images]
- image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
-
- return image_tensors, labels
-
-
-def tensor2im(image_tensor, imtype=np.uint8):
- image_numpy = image_tensor.cpu().float().numpy()
- if image_numpy.shape[0] == 1:
- image_numpy = np.tile(image_numpy, (3, 1, 1))
- image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
- return image_numpy.astype(imtype)
-
-
-def save_image(image_numpy, image_path):
- image_pil = Image.fromarray(image_numpy)
- image_pil.save(image_path)
diff --git a/dtrb/demo.py b/dtrb/demo.py
deleted file mode 100644
index d3d68fe..0000000
--- a/dtrb/demo.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import string
-import argparse
-
-import torch
-import torch.backends.cudnn as cudnn
-import torch.utils.data
-import torch.nn.functional as F
-
-from utils import CTCLabelConverter, AttnLabelConverter
-from dataset import RawDataset, AlignCollate
-from model import Model
-
-import os
-from datetime import datetime
-import pandas as pd
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-def demo(opt):
- """ model configuration """
- if 'CTC' in opt.Prediction:
- converter = CTCLabelConverter(opt.character)
- else:
- converter = AttnLabelConverter(opt.character)
- opt.num_class = len(converter.character)
-
- if opt.rgb:
- opt.input_channel = 3
- model = Model(opt)
- print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
- opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
- opt.SequenceModeling, opt.Prediction)
- model = torch.nn.DataParallel(model).to(device)
-
- # load model
- print('loading pretrained model from %s' % opt.saved_model)
- model.load_state_dict(torch.load(opt.saved_model, map_location=device))
- print(opt.image_folder)
- # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
- AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
- demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset
- demo_loader = torch.utils.data.DataLoader(
- demo_data, batch_size=opt.batch_size,
- shuffle=False,
- num_workers=int(opt.workers),
- collate_fn=AlignCollate_demo, pin_memory=True)
-
- file_list = os.listdir(opt.image_folder)
- file_jpg = len([file for file in file_list if file.endswith('.jpg')])
- cnt = 1
- date_flag=datetime.today().strftime("%Y%m%d")
- f = open(opt.image_folder + f'\\{date_flag}.csv', 'w', encoding='utf-8-sig', newline='')
- f.write('날짜,흥,진용,정답,검증,Score\n')
- f.close()
-
- # predict
- model.eval()
- with torch.no_grad():
- for image_tensors, image_path_list in demo_loader:
- batch_size = image_tensors.size(0)
- image = image_tensors.to(device)
- # For max length prediction
- length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
- text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
-
- if 'CTC' in opt.Prediction:
- preds = model(image, text_for_pred)
-
- # Select max probabilty (greedy decoding) then decode index to character
- preds_size = torch.IntTensor([preds.size(1)] * batch_size)
- _, preds_index = preds.max(2)
- # preds_index = preds_index.view(-1)
- preds_str = converter.decode(preds_index, preds_size)
-
- else:
- preds = model(image, text_for_pred, is_train=False)
-
- # select max probabilty (greedy decoding) then decode index to character
- _, preds_index = preds.max(2)
- preds_str = converter.decode(preds_index, length_for_pred)
-
- # log = open(f'./log_demo_result.txt', 'a')
- dashed_line = '-' * 80
- head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
-
- # print(f'{dashed_line}\n{head}\n{dashed_line}')
- # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
- f = open(opt.image_folder + f'\\{date_flag}.csv', 'a', encoding='utf-8-sig', newline='')
-
-
- preds_prob = F.softmax(preds, dim=2)
- preds_max_prob, _ = preds_prob.max(dim=2)
- for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
- if 'Attn' in opt.Prediction:
- pred_EOS = pred.find('[s]')
- pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
- pred_max_prob = pred_max_prob[:pred_EOS]
-
- # calculate confidence score (= multiply of pred_max_prob)
- confidence_score = pred_max_prob.cumprod(dim=0)[-1]
-
- # print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
- # log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')
- cnt = cnt +1
- origin_name = img_name.split("\\")
- origin_name = origin_name[-1]
- f.write(origin_name[:16] +','+origin_name[16:-4] + ','+ pred + ',' + ',' + ','+ f'{confidence_score:0.4f}'+ ',' +'\n')
- f.close()
- # log.close()
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
- parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
- parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
- parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
- """ Data processing """
- parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
- parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
- parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
- parser.add_argument('--rgb', action='store_true', help='use rgb input')
- parser.add_argument('--character', type=str, default='0123456789가나다라마거너더러머버서어저고노도로모보소오조구누두루무부수우주아바사자하허호배', help='character label')
- parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
- parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
- """ Model Architecture """
- parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
- parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
- parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
- parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
- parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
- parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
- parser.add_argument('--output_channel', type=int, default=256,
- help='the number of output channel of Feature extractor')
- parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
-
- opt = parser.parse_args()
- print(opt)
-
- """ vocab / character number configuration """
- if opt.sensitive:
- opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
-
- cudnn.benchmark = True
- cudnn.deterministic = True
- opt.num_gpu = torch.cuda.device_count()
-
- demo(opt)
- print("작업 완료")
diff --git a/dtrb/lpr.py b/dtrb/lpr.py
deleted file mode 100644
index d4d49fb..0000000
--- a/dtrb/lpr.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import string
-import argparse
-import types
-
-import torch
-import torch.backends.cudnn as cudnn
-import torch.utils.data
-import torch.nn.functional as F
-
-from dtrb.utils import CTCLabelConverter, AttnLabelConverter
-from dtrb.dataset import RawDataset, AlignCollate
-from dtrb.OCRmodel import Model
-
-import os
-from datetime import datetime
-import pandas as pd
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-class W_LPR(object):
-
- def load_ocr():
-
- opt = argparse.Namespace(workers = 4,batch_size = 192,saved_model = f'dtrb/models/best_accuracy.pth',batch_max_length = 25,
- imgH = 32, imgW = 100, character = '0123456789가나다라마거너더러머버서어저고노도로모보소오조구누두루무부수우주아바사자하허호배',
- Transformation = 'TPS', FeatureExtraction = 'VGG', SequenceModeling = 'BiLSTM', Prediction = 'CTC',
- num_fiducial = 20, input_channel = 1, output_channel = 256,hidden_size = 256)
-
- model = Model(opt)
- model = torch.nn.DataParallel(model).to(device)
-
- # load model
- print('모델 불러오는중...')
- model.load_state_dict(torch.load(r'models/best_accuracy.pth', map_location=device))
-
- return model.eval()
-
- def read_ocr(model,img):
- length_for_pred = torch.IntTensor([25] * 192).to(device)
- text_for_pred = torch.LongTensor(192, 26).fill_(0).to(device)
- preds = model(img, text_for_pred)
- preds_size = torch.IntTensor([preds.size(1)] * 192)
- _, preds_index = preds.max(2)
-
- preds_str = CTCLabelConverter.decode((preds_index, length_for_pred))
- return preds_str
-
-class CTCLabelConverter(object):
- """ Convert between text-label and text-index """
-
- def __init__(self, character):
- # character (str): set of the possible characters.
- dict_character = list(character)
-
- self.dict = {}
- for i, char in enumerate(dict_character):
- # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
- self.dict[char] = i + 1
-
- self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)
-
- def encode(self, text, batch_max_length=25):
- """convert text-label into text-index.
- input:
- text: text labels of each image. [batch_size]
- batch_max_length: max length of text label in the batch. 25 by default
-
- output:
- text: text index for CTCLoss. [batch_size, batch_max_length]
- length: length of each text. [batch_size]
- """
- length = [len(s) for s in text]
-
- # The index used for padding (=0) would not affect the CTC loss calculation.
- batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
- for i, t in enumerate(text):
- text = list(t)
- text = [self.dict[char] for char in text]
- batch_text[i][:len(text)] = torch.LongTensor(text)
- return (batch_text.to(device), torch.IntTensor(length).to(device))
-
- def decode(self, text_index, length):
- """ convert text-index into text-label. """
- texts = []
- for index, l in enumerate(length):
- t = text_index[index, :]
-
- char_list = []
- for i in range(l):
- if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
- char_list.append(self.character[t[i]])
- text = ''.join(char_list)
-
- texts.append(text)
- return texts
\ No newline at end of file
diff --git a/dtrb/modules/feature_extraction.py b/dtrb/modules/feature_extraction.py
deleted file mode 100644
index b5f3004..0000000
--- a/dtrb/modules/feature_extraction.py
+++ /dev/null
@@ -1,246 +0,0 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class VGG_FeatureExtractor(nn.Module):
- """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
-
- def __init__(self, input_channel, output_channel=512):
- super(VGG_FeatureExtractor, self).__init__()
- self.output_channel = [int(output_channel / 8), int(output_channel / 4),
- int(output_channel / 2), output_channel] # [64, 128, 256, 512]
- self.ConvNet = nn.Sequential(
- nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # 64x16x50
- nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # 128x8x25
- nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25
- nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
- nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
- nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
- nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25
- nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
- nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
- nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
- nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
-
- def forward(self, input):
- return self.ConvNet(input)
-
-
-class RCNN_FeatureExtractor(nn.Module):
- """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """
-
- def __init__(self, input_channel, output_channel=512):
- super(RCNN_FeatureExtractor, self).__init__()
- self.output_channel = [int(output_channel / 8), int(output_channel / 4),
- int(output_channel / 2), output_channel] # [64, 128, 256, 512]
- self.ConvNet = nn.Sequential(
- nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # 64 x 16 x 50
- GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1),
- nn.MaxPool2d(2, 2), # 64 x 8 x 25
- GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1),
- nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26
- GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1),
- nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27
- nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False),
- nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26
-
- def forward(self, input):
- return self.ConvNet(input)
-
-
-class ResNet_FeatureExtractor(nn.Module):
- """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
-
- def __init__(self, input_channel, output_channel=512):
- super(ResNet_FeatureExtractor, self).__init__()
- self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
-
- def forward(self, input):
- return self.ConvNet(input)
-
-
-# For Gated RCNN
-class GRCL(nn.Module):
-
- def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad):
- super(GRCL, self).__init__()
- self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False)
- self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False)
- self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False)
- self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False)
-
- self.BN_x_init = nn.BatchNorm2d(output_channel)
-
- self.num_iteration = num_iteration
- self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)]
- self.GRCL = nn.Sequential(*self.GRCL)
-
- def forward(self, input):
- """ The input of GRCL is consistant over time t, which is denoted by u(0)
- thus wgf_u / wf_u is also consistant over time t.
- """
- wgf_u = self.wgf_u(input)
- wf_u = self.wf_u(input)
- x = F.relu(self.BN_x_init(wf_u))
-
- for i in range(self.num_iteration):
- x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x))
-
- return x
-
-
-class GRCL_unit(nn.Module):
-
- def __init__(self, output_channel):
- super(GRCL_unit, self).__init__()
- self.BN_gfu = nn.BatchNorm2d(output_channel)
- self.BN_grx = nn.BatchNorm2d(output_channel)
- self.BN_fu = nn.BatchNorm2d(output_channel)
- self.BN_rx = nn.BatchNorm2d(output_channel)
- self.BN_Gx = nn.BatchNorm2d(output_channel)
-
- def forward(self, wgf_u, wgr_x, wf_u, wr_x):
- G_first_term = self.BN_gfu(wgf_u)
- G_second_term = self.BN_grx(wgr_x)
- G = F.sigmoid(G_first_term + G_second_term)
-
- x_first_term = self.BN_fu(wf_u)
- x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G)
- x = F.relu(x_first_term + x_second_term)
-
- return x
-
-
-class BasicBlock(nn.Module):
- expansion = 1
-
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(BasicBlock, self).__init__()
- self.conv1 = self._conv3x3(inplanes, planes)
- self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = self._conv3x3(planes, planes)
- self.bn2 = nn.BatchNorm2d(planes)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def _conv3x3(self, in_planes, out_planes, stride=1):
- "3x3 convolution with padding"
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
-
- def forward(self, x):
- residual = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- residual = self.downsample(x)
- out += residual
- out = self.relu(out)
-
- return out
-
-
-class ResNet(nn.Module):
-
- def __init__(self, input_channel, output_channel, block, layers):
- super(ResNet, self).__init__()
-
- self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
-
- self.inplanes = int(output_channel / 8)
- self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
- kernel_size=3, stride=1, padding=1, bias=False)
- self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
- self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
- kernel_size=3, stride=1, padding=1, bias=False)
- self.bn0_2 = nn.BatchNorm2d(self.inplanes)
- self.relu = nn.ReLU(inplace=True)
-
- self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
- self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
- self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
- 0], kernel_size=3, stride=1, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
-
- self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
- self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
- self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
- 1], kernel_size=3, stride=1, padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
-
- self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
- self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
- self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
- 2], kernel_size=3, stride=1, padding=1, bias=False)
- self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
-
- self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
- self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
- 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
- self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
- self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
- 3], kernel_size=2, stride=1, padding=0, bias=False)
- self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
-
- def _make_layer(self, block, planes, blocks, stride=1):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- nn.Conv2d(self.inplanes, planes * block.expansion,
- kernel_size=1, stride=stride, bias=False),
- nn.BatchNorm2d(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(self.inplanes, planes, stride, downsample))
- self.inplanes = planes * block.expansion
- for i in range(1, blocks):
- layers.append(block(self.inplanes, planes))
-
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = self.conv0_1(x)
- x = self.bn0_1(x)
- x = self.relu(x)
- x = self.conv0_2(x)
- x = self.bn0_2(x)
- x = self.relu(x)
-
- x = self.maxpool1(x)
- x = self.layer1(x)
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.maxpool2(x)
- x = self.layer2(x)
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.relu(x)
-
- x = self.maxpool3(x)
- x = self.layer3(x)
- x = self.conv3(x)
- x = self.bn3(x)
- x = self.relu(x)
-
- x = self.layer4(x)
- x = self.conv4_1(x)
- x = self.bn4_1(x)
- x = self.relu(x)
- x = self.conv4_2(x)
- x = self.bn4_2(x)
- x = self.relu(x)
-
- return x
diff --git a/dtrb/modules/prediction.py b/dtrb/modules/prediction.py
deleted file mode 100644
index b6c3cb3..0000000
--- a/dtrb/modules/prediction.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-class Attention(nn.Module):
-
- def __init__(self, input_size, hidden_size, num_classes):
- super(Attention, self).__init__()
- self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
- self.hidden_size = hidden_size
- self.num_classes = num_classes
- self.generator = nn.Linear(hidden_size, num_classes)
-
- def _char_to_onehot(self, input_char, onehot_dim=38):
- input_char = input_char.unsqueeze(1)
- batch_size = input_char.size(0)
- one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
- one_hot = one_hot.scatter_(1, input_char, 1)
- return one_hot
-
- def forward(self, batch_H, text, is_train=True, batch_max_length=25):
- """
- input:
- batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
- text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
- output: probability distribution at each step [batch_size x num_steps x num_classes]
- """
- batch_size = batch_H.size(0)
- num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.
-
- output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
- hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
- torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))
-
- if is_train:
- for i in range(num_steps):
- # one-hot vectors for a i-th char. in a batch
- char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
- # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
- hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
- output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell)
- probs = self.generator(output_hiddens)
-
- else:
- targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token
- probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)
-
- for i in range(num_steps):
- char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
- hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
- probs_step = self.generator(hidden[0])
- probs[:, i, :] = probs_step
- _, next_input = probs_step.max(1)
- targets = next_input
-
- return probs # batch_size x num_steps x num_classes
-
-
-class AttentionCell(nn.Module):
-
- def __init__(self, input_size, hidden_size, num_embeddings):
- super(AttentionCell, self).__init__()
- self.i2h = nn.Linear(input_size, hidden_size, bias=False)
- self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias
- self.score = nn.Linear(hidden_size, 1, bias=False)
- self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
- self.hidden_size = hidden_size
-
- def forward(self, prev_hidden, batch_H, char_onehots):
- # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
- batch_H_proj = self.i2h(batch_H)
- prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
- e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1
-
- alpha = F.softmax(e, dim=1)
- context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
- concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
- cur_hidden = self.rnn(concat_context, prev_hidden)
- return cur_hidden, alpha
diff --git a/dtrb/modules/sequence_modeling.py b/dtrb/modules/sequence_modeling.py
deleted file mode 100644
index af32c59..0000000
--- a/dtrb/modules/sequence_modeling.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch.nn as nn
-
-
-class BidirectionalLSTM(nn.Module):
-
- def __init__(self, input_size, hidden_size, output_size):
- super(BidirectionalLSTM, self).__init__()
- self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
- self.linear = nn.Linear(hidden_size * 2, output_size)
-
- def forward(self, input):
- """
- input : visual feature [batch_size x T x input_size]
- output : contextual feature [batch_size x T x output_size]
- """
- self.rnn.flatten_parameters()
- recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
- output = self.linear(recurrent) # batch_size x T x output_size
- return output
diff --git a/dtrb/modules/transformation.py b/dtrb/modules/transformation.py
deleted file mode 100644
index 875d1ae..0000000
--- a/dtrb/modules/transformation.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-class TPS_SpatialTransformerNetwork(nn.Module):
- """ Rectification Network of RARE, namely TPS based STN """
-
- def __init__(self, F, I_size, I_r_size, I_channel_num=1):
- """ Based on RARE TPS
- input:
- batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
- I_size : (height, width) of the input image I
- I_r_size : (height, width) of the rectified image I_r
- I_channel_num : the number of channels of the input image I
- output:
- batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
- """
- super(TPS_SpatialTransformerNetwork, self).__init__()
- self.F = F
- self.I_size = I_size
- self.I_r_size = I_r_size # = (I_r_height, I_r_width)
- self.I_channel_num = I_channel_num
- self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
- self.GridGenerator = GridGenerator(self.F, self.I_r_size)
-
- def forward(self, batch_I):
- batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
- build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
- build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
-
- if torch.__version__ > "1.2.0":
- batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
- else:
- batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
-
- return batch_I_r
-
-
-class LocalizationNetwork(nn.Module):
- """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
-
- def __init__(self, F, I_channel_num):
- super(LocalizationNetwork, self).__init__()
- self.F = F
- self.I_channel_num = I_channel_num
- self.conv = nn.Sequential(
- nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
- bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
- nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
- nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
- nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
- nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
- nn.AdaptiveAvgPool2d(1) # batch_size x 512
- )
-
- self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
- self.localization_fc2 = nn.Linear(256, self.F * 2)
-
- # Init fc2 in LocalizationNetwork
- self.localization_fc2.weight.data.fill_(0)
- """ see RARE paper Fig. 6 (a) """
- ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
- ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
- ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
- ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
- ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
- initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
- self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
-
- def forward(self, batch_I):
- """
- input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
- output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
- """
- batch_size = batch_I.size(0)
- features = self.conv(batch_I).view(batch_size, -1)
- batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
- return batch_C_prime
-
-
-class GridGenerator(nn.Module):
- """ Grid Generator of RARE, which produces P_prime by multipling T with P """
-
- def __init__(self, F, I_r_size):
- """ Generate P_hat and inv_delta_C for later """
- super(GridGenerator, self).__init__()
- self.eps = 1e-6
- self.I_r_height, self.I_r_width = I_r_size
- self.F = F
- self.C = self._build_C(self.F) # F x 2
- self.P = self._build_P(self.I_r_width, self.I_r_height)
- ## for multi-gpu, you need register buffer
- self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
- self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
- ## for fine-tuning with different image width, you may use below instead of self.register_buffer
- #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
- #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
-
- def _build_C(self, F):
- """ Return coordinates of fiducial points in I_r; C """
- ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
- ctrl_pts_y_top = -1 * np.ones(int(F / 2))
- ctrl_pts_y_bottom = np.ones(int(F / 2))
- ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
- ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
- C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
- return C # F x 2
-
- def _build_inv_delta_C(self, F, C):
- """ Return inv_delta_C which is needed to calculate T """
- hat_C = np.zeros((F, F), dtype=float) # F x F
- for i in range(0, F):
- for j in range(i, F):
- r = np.linalg.norm(C[i] - C[j])
- hat_C[i, j] = r
- hat_C[j, i] = r
- np.fill_diagonal(hat_C, 1)
- hat_C = (hat_C ** 2) * np.log(hat_C)
- # print(C.shape, hat_C.shape)
- delta_C = np.concatenate( # F+3 x F+3
- [
- np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
- np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
- np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
- ],
- axis=0
- )
- inv_delta_C = np.linalg.inv(delta_C)
- return inv_delta_C # F+3 x F+3
-
- def _build_P(self, I_r_width, I_r_height):
- I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
- I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
- P = np.stack( # self.I_r_width x self.I_r_height x 2
- np.meshgrid(I_r_grid_x, I_r_grid_y),
- axis=2
- )
- return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
-
- def _build_P_hat(self, F, C, P):
- n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
- P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
- C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
- P_diff = P_tile - C_tile # n x F x 2
- rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
- rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
- P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
- return P_hat # n x F+3
-
- def build_P_prime(self, batch_C_prime):
- """ Generate Grid from batch_C_prime [batch_size x F x 2] """
- batch_size = batch_C_prime.size(0)
- batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
- batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
- batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
- batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2
- batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
- batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
- return batch_P_prime # batch_size x n x 2
diff --git a/dtrb/test.py b/dtrb/test.py
deleted file mode 100644
index 68ac5da..0000000
--- a/dtrb/test.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import os
-import time
-import string
-import argparse
-import re
-
-import torch
-import torch.backends.cudnn as cudnn
-import torch.utils.data
-import torch.nn.functional as F
-import numpy as np
-from nltk.metrics.distance import edit_distance
-
-from utils import CTCLabelConverter, AttnLabelConverter, Averager
-from dataset import hierarchical_dataset, AlignCollate
-from model import Model
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False):
- """ evaluation with 10 benchmark evaluation datasets """
- # The evaluation datasets, dataset order is same with Table 1 in our paper.
- eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857',
- 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
-
- # # To easily compute the total accuracy of our paper.
- # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
- # 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']
-
- if calculate_infer_time:
- evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image.
- else:
- evaluation_batch_size = opt.batch_size
-
- list_accuracy = []
- total_forward_time = 0
- total_evaluation_data_number = 0
- total_correct_number = 0
- log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a')
- dashed_line = '-' * 80
- print(dashed_line)
- log.write(dashed_line + '\n')
- for eval_data in eval_data_list:
- eval_data_path = os.path.join(opt.eval_data, eval_data)
- AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
- eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt)
- evaluation_loader = torch.utils.data.DataLoader(
- eval_data, batch_size=evaluation_batch_size,
- shuffle=False,
- num_workers=int(opt.workers),
- collate_fn=AlignCollate_evaluation, pin_memory=True)
-
- _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation(
- model, criterion, evaluation_loader, converter, opt)
- list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
- total_forward_time += infer_time
- total_evaluation_data_number += len(eval_data)
- total_correct_number += accuracy_by_best_model * length_of_data
- log.write(eval_data_log)
- print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}')
- log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n')
- print(dashed_line)
- log.write(dashed_line + '\n')
-
- averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
- total_accuracy = total_correct_number / total_evaluation_data_number
- params_num = sum([np.prod(p.size()) for p in model.parameters()])
-
- evaluation_log = 'accuracy: '
- for name, accuracy in zip(eval_data_list, list_accuracy):
- evaluation_log += f'{name}: {accuracy}\t'
- evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t'
- evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
- print(evaluation_log)
- log.write(evaluation_log + '\n')
- log.close()
-
- return None
-
-
-def validation(model, criterion, evaluation_loader, converter, opt):
- """ validation or evaluation """
- n_correct = 0
- norm_ED = 0
- length_of_data = 0
- infer_time = 0
- valid_loss_avg = Averager()
-
- for i, (image_tensors, labels) in enumerate(evaluation_loader):
- batch_size = image_tensors.size(0)
- length_of_data = length_of_data + batch_size
- image = image_tensors.to(device)
- # For max length prediction
- length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
- text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
-
- text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
-
- start_time = time.time()
- if 'CTC' in opt.Prediction:
- preds = model(image, text_for_pred)
- forward_time = time.time() - start_time
-
- # Calculate evaluation loss for CTC deocder.
- preds_size = torch.IntTensor([preds.size(1)] * batch_size)
- # permute 'preds' to use CTCloss format
- if opt.baiduCTC:
- cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
- else:
- cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
-
- # Select max probabilty (greedy decoding) then decode index to character
- if opt.baiduCTC:
- _, preds_index = preds.max(2)
- preds_index = preds_index.view(-1)
- else:
- _, preds_index = preds.max(2)
- preds_str = converter.decode(preds_index.data, preds_size.data)
-
- else:
- preds = model(image, text_for_pred, is_train=False)
- forward_time = time.time() - start_time
-
- preds = preds[:, :text_for_loss.shape[1] - 1, :]
- target = text_for_loss[:, 1:] # without [GO] Symbol
- cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
-
- # select max probabilty (greedy decoding) then decode index to character
- _, preds_index = preds.max(2)
- preds_str = converter.decode(preds_index, length_for_pred)
- labels = converter.decode(text_for_loss[:, 1:], length_for_loss)
-
- infer_time += forward_time
- valid_loss_avg.add(cost)
-
- # calculate accuracy & confidence score
- preds_prob = F.softmax(preds, dim=2)
- preds_max_prob, _ = preds_prob.max(dim=2)
- confidence_score_list = []
- for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
- if 'Attn' in opt.Prediction:
- gt = gt[:gt.find('[s]')]
- pred_EOS = pred.find('[s]')
- pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
- pred_max_prob = pred_max_prob[:pred_EOS]
-
- # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting.
- if opt.sensitive and opt.data_filtering_off:
- pred = pred.lower()
- gt = gt.lower()
- alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
- out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
- pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
- gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)
-
- if pred == gt:
- n_correct += 1
-
- '''
- (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks
- "For each word we calculate the normalized edit distance to the length of the ground truth transcription."
- if len(gt) == 0:
- norm_ED += 1
- else:
- norm_ED += edit_distance(pred, gt) / len(gt)
- '''
-
- # ICDAR2019 Normalized Edit Distance
- if len(gt) == 0 or len(pred) == 0:
- norm_ED += 0
- elif len(gt) > len(pred):
- norm_ED += 1 - edit_distance(pred, gt) / len(gt)
- else:
- norm_ED += 1 - edit_distance(pred, gt) / len(pred)
-
- # calculate confidence score (= multiply of pred_max_prob)
- try:
- confidence_score = pred_max_prob.cumprod(dim=0)[-1]
- except:
- confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s])
- confidence_score_list.append(confidence_score)
- # print(pred, gt, pred==gt, confidence_score)
-
- accuracy = n_correct / float(length_of_data) * 100
- norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance
-
- return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
-
-
-def test(opt):
- """ model configuration """
- if 'CTC' in opt.Prediction:
- converter = CTCLabelConverter(opt.character)
- else:
- converter = AttnLabelConverter(opt.character)
- opt.num_class = len(converter.character)
-
- if opt.rgb:
- opt.input_channel = 3
- model = Model(opt)
- print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
- opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
- opt.SequenceModeling, opt.Prediction)
- model = torch.nn.DataParallel(model).to(device)
-
- # load model
- print('loading pretrained model from %s' % opt.saved_model)
- model.load_state_dict(torch.load(opt.saved_model, map_location=device))
- opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])
- # print(model)
-
- """ keep evaluation model and result logs """
- os.makedirs(f'./result/{opt.exp_name}', exist_ok=True)
- os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/')
-
- """ setup loss """
- if 'CTC' in opt.Prediction:
- criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
- else:
- criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
-
- """ evaluation """
- model.eval()
- with torch.no_grad():
- if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets
- benchmark_all_eval(model, criterion, converter, opt)
- else:
- log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a')
- AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
- eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt)
- evaluation_loader = torch.utils.data.DataLoader(
- eval_data, batch_size=opt.batch_size,
- shuffle=False,
- num_workers=int(opt.workers),
- collate_fn=AlignCollate_evaluation, pin_memory=True)
- _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
- model, criterion, evaluation_loader, converter, opt)
- log.write(eval_data_log)
- print(f'{accuracy_by_best_model:0.3f}')
- log.write(f'{accuracy_by_best_model:0.3f}\n')
- log.close()
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--eval_data', required=True, help='path to evaluation dataset')
- parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets')
- parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
- parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
- parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
- """ Data processing """
- parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
- parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
- parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
- parser.add_argument('--rgb', action='store_true', help='use rgb input')
- parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
- parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
- parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
- parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
- parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
- """ Model Architecture """
- parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
- parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
- parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
- parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
- parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
- parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
- parser.add_argument('--output_channel', type=int, default=512,
- help='the number of output channel of Feature extractor')
- parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
-
- opt = parser.parse_args()
-
- """ vocab / character number configuration """
- if opt.sensitive:
- opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
-
- cudnn.benchmark = True
- cudnn.deterministic = True
- opt.num_gpu = torch.cuda.device_count()
-
- test(opt)
diff --git a/dtrb/train.py b/dtrb/train.py
deleted file mode 100644
index 7f3a66b..0000000
--- a/dtrb/train.py
+++ /dev/null
@@ -1,317 +0,0 @@
-import os
-import sys
-import time
-import random
-import string
-import argparse
-
-import torch
-import torch.backends.cudnn as cudnn
-import torch.nn.init as init
-import torch.optim as optim
-import torch.utils.data
-import numpy as np
-
-from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
-from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
-from model import Model
-from test import validation
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-def train(opt):
- """ dataset preparation """
- if not opt.data_filtering_off:
- print('Filtering the images containing characters which are not in opt.character')
- print('Filtering the images whose label is longer than opt.batch_max_length')
- # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130
-
- opt.select_data = opt.select_data.split('-')
- opt.batch_ratio = opt.batch_ratio.split('-')
- train_dataset = Batch_Balanced_Dataset(opt)
-
- log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
- AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
- valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
- valid_loader = torch.utils.data.DataLoader(
- valid_dataset, batch_size=opt.batch_size,
- shuffle=True, # 'True' to check training progress with validation function.
- num_workers=int(opt.workers),
- collate_fn=AlignCollate_valid, pin_memory=True)
- log.write(valid_dataset_log)
- print('-' * 80)
- log.write('-' * 80 + '\n')
- log.close()
-
- """ model configuration """
- if 'CTC' in opt.Prediction:
- if opt.baiduCTC:
- converter = CTCLabelConverterForBaiduWarpctc(opt.character)
- else:
- converter = CTCLabelConverter(opt.character)
- else:
- converter = AttnLabelConverter(opt.character)
- opt.num_class = len(converter.character)
-
- if opt.rgb:
- opt.input_channel = 3
- model = Model(opt)
- print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
- opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
- opt.SequenceModeling, opt.Prediction)
-
- # weight initialization
- for name, param in model.named_parameters():
- if 'localization_fc2' in name:
- print(f'Skip {name} as it is already initialized')
- continue
- try:
- if 'bias' in name:
- init.constant_(param, 0.0)
- elif 'weight' in name:
- init.kaiming_normal_(param)
- except Exception as e: # for batchnorm.
- if 'weight' in name:
- param.data.fill_(1)
- continue
-
- # data parallel for multi-GPU
- model = torch.nn.DataParallel(model).to(device)
- model.train()
- if opt.saved_model != '':
- print(f'loading pretrained model from {opt.saved_model}')
- if opt.FT:
- model.load_state_dict(torch.load(opt.saved_model), strict=False)
- else:
- model.load_state_dict(torch.load(opt.saved_model))
- print("Model:")
- print(model)
-
- """ setup loss """
- if 'CTC' in opt.Prediction:
- if opt.baiduCTC:
- # need to install warpctc. see our guideline.
- from warpctc_pytorch import CTCLoss
- criterion = CTCLoss()
- else:
- criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
- else:
- criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
- # loss averager
- loss_avg = Averager()
-
- # filter that only require gradient decent
- filtered_parameters = []
- params_num = []
- for p in filter(lambda p: p.requires_grad, model.parameters()):
- filtered_parameters.append(p)
- params_num.append(np.prod(p.size()))
- print('Trainable params num : ', sum(params_num))
- # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]
-
- # setup optimizer
- if opt.adam:
- optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
- else:
- optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
- print("Optimizer:")
- print(optimizer)
-
- """ final options """
- # print(opt)
- with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
- opt_log = '------------ Options -------------\n'
- args = vars(opt)
- for k, v in args.items():
- opt_log += f'{str(k)}: {str(v)}\n'
- opt_log += '---------------------------------------\n'
- print(opt_log)
- opt_file.write(opt_log)
-
- """ start training """
- start_iter = 0
- if opt.saved_model != '':
- try:
- start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
- print(f'continue to train, start_iter: {start_iter}')
- except:
- pass
-
- start_time = time.time()
- best_accuracy = -1
- best_norm_ED = -1
- iteration = start_iter
-
- while(True):
- # train part
- image_tensors, labels = train_dataset.get_batch()
- image = image_tensors.to(device)
- text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
- batch_size = image.size(0)
-
- if 'CTC' in opt.Prediction:
- preds = model(image, text)
- preds_size = torch.IntTensor([preds.size(1)] * batch_size)
- if opt.baiduCTC:
- preds = preds.permute(1, 0, 2) # to use CTCLoss format
- cost = criterion(preds, text, preds_size, length) / batch_size
- else:
- preds = preds.log_softmax(2).permute(1, 0, 2)
- cost = criterion(preds, text, preds_size, length)
-
- else:
- preds = model(image, text[:, :-1]) # align with Attention.forward
- target = text[:, 1:] # without [GO] Symbol
- cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
-
- model.zero_grad()
- cost.backward()
- torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default)
- optimizer.step()
-
- loss_avg.add(cost)
-
- # validation part
- if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0'
- elapsed_time = time.time() - start_time
- # for log
- with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
- model.eval()
- with torch.no_grad():
- valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
- model, criterion, valid_loader, converter, opt)
- model.train()
-
- # training loss and validation loss
- loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
- loss_avg.reset()
-
- current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
-
- # keep best accuracy model (on valid dataset)
- if current_accuracy > best_accuracy:
- best_accuracy = current_accuracy
- torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
- if current_norm_ED > best_norm_ED:
- best_norm_ED = current_norm_ED
- torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
- best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
-
- loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
- print(loss_model_log)
- log.write(loss_model_log + '\n')
-
- # show some predicted results
- dashed_line = '-' * 80
- head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
- predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
- for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
- if 'Attn' in opt.Prediction:
- gt = gt[:gt.find('[s]')]
- pred = pred[:pred.find('[s]')]
-
- predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
- predicted_result_log += f'{dashed_line}'
- print(predicted_result_log)
- log.write(predicted_result_log + '\n')
-
- # save model per 1e+5 iter.
- if (iteration + 1) % 1e+5 == 0:
- torch.save(
- model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
-
- if (iteration + 1) == opt.num_iter:
- print('end the training')
- sys.exit()
- iteration += 1
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--exp_name', help='Where to store logs and models')
- parser.add_argument('--train_data', required=True, help='path to training dataset')
- parser.add_argument('--valid_data', required=True, help='path to validation dataset')
- parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting')
- parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
- parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
- parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
- parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation')
- parser.add_argument('--saved_model', default='', help="path to model to continue training")
- parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
- parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
- parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
- parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
- parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
- parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
- parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
- parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode')
- """ Data processing """
- parser.add_argument('--select_data', type=str, default='MJ-ST',
- help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
- parser.add_argument('--batch_ratio', type=str, default='0.5-0.5',
- help='assign ratio for each selected data in the batch')
- parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
- help='total data usage ratio, this ratio is multiplied to total number of data.')
- parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
- parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
- parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
- parser.add_argument('--rgb', action='store_true', help='use rgb input')
- parser.add_argument('--character', type=str,
- default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
- parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
- parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
- parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
- """ Model Architecture """
- parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
- parser.add_argument('--FeatureExtraction', type=str, required=True,
- help='FeatureExtraction stage. VGG|RCNN|ResNet')
- parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
- parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
- parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
- parser.add_argument('--input_channel', type=int, default=1,
- help='the number of input channel of Feature extractor')
- parser.add_argument('--output_channel', type=int, default=512,
- help='the number of output channel of Feature extractor')
- parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
-
- opt = parser.parse_args()
-
- if not opt.exp_name:
- opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
- opt.exp_name += f'-Seed{opt.manualSeed}'
- # print(opt.exp_name)
-
- os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)
-
- """ vocab / character number configuration """
- if opt.sensitive:
- # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
- opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
-
- """ Seed and GPU setting """
- # print("Random Seed: ", opt.manualSeed)
- random.seed(opt.manualSeed)
- np.random.seed(opt.manualSeed)
- torch.manual_seed(opt.manualSeed)
- torch.cuda.manual_seed(opt.manualSeed)
-
- cudnn.benchmark = True
- cudnn.deterministic = True
- opt.num_gpu = torch.cuda.device_count()
- # print('device count', opt.num_gpu)
- if opt.num_gpu > 1:
- print('------ Use multi-GPU setting ------')
- print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
- # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
- opt.workers = opt.workers * opt.num_gpu
- opt.batch_size = opt.batch_size * opt.num_gpu
-
- """ previous version
- print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size)
- opt.batch_size = opt.batch_size * opt.num_gpu
- print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
- If you dont care about it, just commnet out these line.)
- opt.num_iter = int(opt.num_iter / opt.num_gpu)
- """
-
- train(opt)
diff --git a/dtrb/utils.py b/dtrb/utils.py
deleted file mode 100644
index e576358..0000000
--- a/dtrb/utils.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import torch
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-class CTCLabelConverter(object):
- """ Convert between text-label and text-index """
-
- def __init__(self, character):
- # character (str): set of the possible characters.
- dict_character = list(character)
-
- self.dict = {}
- for i, char in enumerate(dict_character):
- # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
- self.dict[char] = i + 1
-
- self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)
-
- def encode(self, text, batch_max_length=25):
- """convert text-label into text-index.
- input:
- text: text labels of each image. [batch_size]
- batch_max_length: max length of text label in the batch. 25 by default
-
- output:
- text: text index for CTCLoss. [batch_size, batch_max_length]
- length: length of each text. [batch_size]
- """
- length = [len(s) for s in text]
-
- # The index used for padding (=0) would not affect the CTC loss calculation.
- batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
- for i, t in enumerate(text):
- text = list(t)
- text = [self.dict[char] for char in text]
- batch_text[i][:len(text)] = torch.LongTensor(text)
- return (batch_text.to(device), torch.IntTensor(length).to(device))
-
- def decode(self, text_index, length):
- """ convert text-index into text-label. """
- texts = []
- for index, l in enumerate(length):
- t = text_index[index, :]
-
- char_list = []
- for i in range(l):
- if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
- char_list.append(self.character[t[i]])
- text = ''.join(char_list)
-
- texts.append(text)
- return texts
-
-
-class CTCLabelConverterForBaiduWarpctc(object):
- """ Convert between text-label and text-index for baidu warpctc """
-
- def __init__(self, character):
- # character (str): set of the possible characters.
- dict_character = list(character)
-
- self.dict = {}
- for i, char in enumerate(dict_character):
- # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
- self.dict[char] = i + 1
-
- self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)
-
- def encode(self, text, batch_max_length=25):
- """convert text-label into text-index.
- input:
- text: text labels of each image. [batch_size]
- output:
- text: concatenated text index for CTCLoss.
- [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
- length: length of each text. [batch_size]
- """
- length = [len(s) for s in text]
- text = ''.join(text)
- text = [self.dict[char] for char in text]
-
- return (torch.IntTensor(text), torch.IntTensor(length))
-
- def decode(self, text_index, length):
- """ convert text-index into text-label. """
- texts = []
- index = 0
- for l in length:
- t = text_index[index:index + l]
-
- char_list = []
- for i in range(l):
- if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
- char_list.append(self.character[t[i]])
- text = ''.join(char_list)
-
- texts.append(text)
- index += l
- return texts
-
-
-class AttnLabelConverter(object):
- """ Convert between text-label and text-index """
-
- def __init__(self, character):
- # character (str): set of the possible characters.
- # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
- list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]']
- list_character = list(character)
- self.character = list_token + list_character
-
- self.dict = {}
- for i, char in enumerate(self.character):
- # print(i, char)
- self.dict[char] = i
-
- def encode(self, text, batch_max_length=25):
- """ convert text-label into text-index.
- input:
- text: text labels of each image. [batch_size]
- batch_max_length: max length of text label in the batch. 25 by default
-
- output:
- text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
- text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
- length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size]
- """
- length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
- # batch_max_length = max(length) # this is not allowed for multi-gpu setting
- batch_max_length += 1
- # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
- batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
- for i, t in enumerate(text):
- text = list(t)
- text.append('[s]')
- text = [self.dict[char] for char in text]
- batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token
- return (batch_text.to(device), torch.IntTensor(length).to(device))
-
- def decode(self, text_index, length):
- """ convert text-index into text-label. """
- texts = []
- for index, l in enumerate(length):
- text = ''.join([self.character[i] for i in text_index[index, :]])
- texts.append(text)
- return texts
-
-
-class Averager(object):
- """Compute average for torch.Tensor, used for loss average."""
-
- def __init__(self):
- self.reset()
-
- def add(self, v):
- count = v.data.numel()
- v = v.data.sum()
- self.n_count += count
- self.sum += v
-
- def reset(self):
- self.n_count = 0
- self.sum = 0
-
- def val(self):
- res = 0
- if self.n_count != 0:
- res = self.sum / float(self.n_count)
- return res