作者:Sovit Rath
编译:ronghuaiyang
导读
本文介绍了TrOCR的结构和使用方法,手把手从每一行代码教起。
光学字符识别(OCR)在过去几年中出现了一些创新。它对零售、医疗、银行和许多其他行业的影响是巨大的。尽管有着悠久的历史和一些最先进的模型,研究人员仍在不断创新。与深度学习的许多其他领域一样,OCR也看到了transformer 经网络的重要性和影响。今天,我们有像TrOCR(Transformer OCR)这样的模型,它们在准确性方面确实超过了以前的技术。
在本文中,我们将介绍TrOCR,并重点讨论四个主题:
TrOCR由李等人在论文TrOCR:Transformer-based Optical Character Recognition with Pre-trained Models中介绍。
作者提出了一种背离传统的CNN和RNN的方法,他们使用视觉和语言transformer 模型来构建TrOCR架构。
TrOCR模型由两个阶段组成:
由于其高效的预训练,基于transformer的模型在下游任务中表现得非常好。因此,作者选择了DeIT作为视觉转换器模型。对于解码器阶段,他们选择了RoBERTa或UniLM模型,这取决于TrOCR变体。
下图显示了使用TrOCR的简单OCR pipeline。
在上图中,左块显示视觉transformer 编码器,右块显示语言transformer 解码器。以下是TrOCR推理阶段的简单分解:
需要注意的一点是,在进入视觉transformer模型之前,图像的大小调整为384×384分辨率。这是因为DeIT模型期望图像具有特定的大小。
TrOCR 家族模型包括几个预训练和微调模型。
TrOCR家族中的预训练模型叫做第一阶段模型,这些模型在大量的生成数据上进行训练。数据集中包括百万张打印的文本行图像。
官方的代码仓库中包含了3个不同大小的预训练模型:
越大的模型效果越好,但是越慢。
在预训练步骤之后,模型在IAM手写数据文本图像和SROIE打印收据数据集上进行微调。
IAM手写数据集包含了手写文本,在这个数据集上进行微调使得这个模型在手写文本的效果上好于其他模型。
类似的,SROIE数据集包含了几千个收据图像样本,微调之后,在打印文本上的效果会表现很好。
和预训练步骤的模型一样,手写和打印模型也包含了3个不同大小的模型:
Hugging Face上有TrOCR的所有模型,包括预训练步骤和微调步骤的。
我们会使用2个模型,一个手写微调模型,一个打印微调模型,来进行推理实验。
在Hugging Face上,模型的命名遵循trocr-<model_scale>-<training_stage>规则。
举例说明,在IAM手写数据集训练的TrOCR的小模型叫做trocr-small-handwritten。
我们使用trocr-small-printed和trocr-base-handwritten来进行推理。
首先要安装一些库:Hugging Face transformers, sentencepiece tokenizer.
!pip install -q transformers!pip install -q -U sentencepiece
然后是下面的导入语句:
from transformers import TrOCRProcessor, VisionEncoderDecoderModelfrom PIL import Imagefrom tqdm.auto import tqdmfrom urllib.request import urlretrievefrom zipfile import ZipFile import numpy as npimport matplotlib.pyplot as pltimport torchimport osimport glob
我们需要用到urllib和zipfile 来解压推理数据。
前向过程使用GPU和CPU都可以。
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')
下面的函数是如何下载和解压数据集。
def download_and_unzip(url, save_path): print(f"Downloading and extracting assets....", end="") # Downloading zip file using urllib package. urlretrieve(url, save_path) try: # Extracting zip file using the zipfile package. with ZipFile(save_path) as z: # Extract ZIP file contents in the same directory. z.extractall(os.path.split(save_path)[0]) print("Done") except Exception as e: print("\nInvalid file.", e) URL = r"https://www.dropbox.com/scl/fi/jz74me0vc118akmv5nuzy/images.zip?rlkey=54flzvhh9xxh45czb1c8n3fp3&dl=1"asset_zip_path = os.path.join(os.getcwd(), "images.zip")# Download if assest ZIP does not exists.if not os.path.exists(asset_zip_path): download_and_unzip(URL, asset_zip_path)
上面的代码下载的图像包括:
接下来,我们用一个简单的函数来读取图像。
def read_image(image_path): """ :param image_path: String, path to the input image. Returns: image: PIL Image. """ image = Image.open(image_path).convert('RGB') return image
read_image() 函数的参数为图像路径,返回RGB格式的图像。
我们还写了一个函数来实现OCR的pipeline。
def ocr(image, processor, model): """ :param image: PIL Image. :param processor: Huggingface OCR processor. :param model: Huggingface OCR model. Returns: generated_text: the OCR'd text string. """ # We can directly perform OCR on cropped images. pixel_values = processor(image, return_tensors='pt').pixel_values.to(device) generated_ids = model.generate(pixel_values) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return generated_text
ocr()这个函数需要下面几个参数:
在返回语句之前,有一个batch_decode() 函数,这个实际上就是将生成的IDs转换为输出文本,skip_special_tokens=True表示我们在输出中不需要特殊tokens,比如结束符和开始符。
最后的这个函数用来运行推理新的图像,包括了之前的函数,并显示了输出结果。
def eval_new_data(data_path=None, num_samples=4, model=None): image_paths = glob.glob(data_path) for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)): if i == num_samples: break image = read_image(image_path) text = ocr(image, processor, model) plt.figure(figsize=(7, 4)) plt.imshow(image) plt.title(text) plt.axis('off') plt.show()
eval_new_data() 这个函数的参数为文件夹路径,样本数量,以及模型。
我们加载TrOCR processor和模型来进行打印文本识别。
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-printed')model = VisionEncoderDecoderModel.from_pretrained( 'microsoft/trocr-small-printed').to(device)
要加载TrOCR processor,我们需要使用from_pretrained模块,该模块接收HuggingFace的仓库路径,包含特定的模块。
TrOCR Processor做了哪些事情?
TrOCR模型是一个神经网络,不能直接处理图像,我们需要先将图像处理成合适的格式。TrOCR processor首先将图像缩放到 384×384 的分辨率,然后转换为归一化的tensor格式,然后再进行模型的推理。我们还可以指定tensort的格式,比如,我们转化为pt格式,表示是pytorch的tensor,我们还可以得到TensorFlow的格式。
同样的,我们使用VisionEncoderDecoderModel类加载预训练模型,在上面的代码中,我们加载了trocr-small-printed 模型,并将其加载到设备中。然后,我们调用eval_new_data()函数开始推理。
eval_new_data( data_path=os.path.join('images', 'newspaper', '*'), num_samples=2, model=model)
运行上面的代码可以得到下面的输出:
图像上的文本表示了模型的输出,模型在模糊图像上的表现也很好,在第一张图像上,模型可以预测出所有的标点符号,空格,甚至是破折号。
对于手写文本的推理,我们使用基础模型(比小模型大),我们首先加载手写TrOCR processor和模型。
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')model = VisionEncoderDecoderModel.from_pretrained( 'microsoft/trocr-base-handwritten').to(device)
我们的方法和打印文本模型一样,只是把仓库地址该成需要的模型。
在运行推理时,我们需要改变数据路径。
eval_new_data( data_path=os.path.join('images', 'handwritten', '*'), num_samples=2, model=model)
这里是输出:
这个例子很好的表现了TrOCR 在手写文本上的效果,可以正确识别出所有的字符,甚至是连写的字符。
对于不同的手写风格,模型的效果也很好。视觉和语言模型的组合的威力显现。
TrOCR并不是在所有类型的图像上都能表现很好。举例说明,小模型在弯曲文本上的效果不好,下面是几个例子:
很明显,模型无法理解和识别出STATES 这个词,输出的是<。
这是另外一个例子:
这次,模型能预测出一个词,但是是错误的。
从上面可以看到,TrOCR模型可能在某些场景下表现不好,这种限制同时来自视觉transformer和语言transformer的能力限制,需要一个经过弯曲文本图像训练过的视觉模型,以及能理解这种token的语言模型。
最好的方法是在弯曲文本数据集上微调 TrOCR 模型,我们会在SCUT-CTW1500数据集上进行训练。
OCR 用简单的架构已经发展了很长时间,如今,TrOCR 为该领域带来了新的可能性。我们从介绍 TrOCR 开始,深入研究了它的架构。在此之后,我们介绍了不同的 TrOCR 模型及其训练策略。我们通过运行推理和分析结果来完成本文。
一个简单而有效的应用可以是数字化旧文章和报纸,这些文章和报纸很难人工清晰易读。
但是,在处理弯曲文本和自然场景中的文本时,TrOCR 也有其局限性。我们将在下一篇文章中更深入地探讨这一点,我们将在弯曲文本数据集上微调 TrOCR 模型并解锁新功能。
—END—
英文原文:https://learnopencv.com/trocr-getting-started-with-transformer-based-ocr/
评论列表 (0条)