Building a PDF OCR API with Flask and TrOCR
Introduction
Optical Character Recognition (OCR) has become an essential technology for digitizing documents, extracting text from images, and automating data entry. In this tutorial, we'll explore how to build a production-ready PDF OCR API using Flask, PyTorch, and Microsoft's TrOCR (Transformer-based OCR) model.
The project we'll be examining is a Flask-based web service that performs OCR on PDF files using machine vision and AI models. It's optimized for both CPU and GPU (NVIDIA CUDA) and provides a clean REST API interface.
Architecture Overview
The PDF OCR service follows a straightforward architecture:
- Flask Web Server: Handles HTTP requests and responses
- PDF Processing: Extracts pages from PDF files using PyMuPDF
- Image Processing: Segments images into text lines using OpenCV
- OCR Engine: Uses TrOCR models to extract text from image lines
- Response Formatting: Returns extracted text as JSON
Let's dive into each component.
Core Components
1. Flask Application Setup
The application uses Flask as the web framework. Here's how it's structured:
```python
from flask import Flask, request, jsonify
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the default model once at startup
default_model_name = "microsoft/trocr-base-printed"
processor = TrOCRProcessor.from_pretrained(default_model_name)
model = VisionEncoderDecoderModel.from_pretrained(default_model_name)

def create_app():
    app = Flask(__name__)
    # ... routes defined here
    return app

app = create_app()
```

Key Points:
- The TrOCR model and processor are loaded once at application startup, not per request. This is crucial for performance since loading these models is expensive.
- The `create_app()` function follows the Flask application factory pattern, making the app testable and allowing multiple app instances.
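To make the factory pattern concrete, here is a minimal runnable sketch; the `config` argument and `/ping` route are illustrative assumptions, not code from the project:

```python
from flask import Flask, jsonify

def create_app(config=None):
    """Application factory: each call builds an independent, configurable app."""
    app = Flask(__name__)
    if config:
        app.config.update(config)

    @app.route('/ping')
    def ping():
        return jsonify({'status': 'ok'})

    return app

# The factory makes testing straightforward: build a throwaway app per test
test_app = create_app({'TESTING': True})
with test_app.test_client() as client:
    assert client.get('/ping').get_json() == {'status': 'ok'}
```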
2. Image Segmentation with OpenCV
Before we can extract text, we need to segment the PDF page images into individual lines. This is handled by the `segment_lines()` function:
```python
import cv2
import numpy as np

def segment_lines(image, threshold_value=150, kernel_width=20, kernel_height=1, min_area=50):
    """Segments an image into lines based on provided image processing parameters."""
    # Convert the PIL image (RGB channel order) to a grayscale numpy array
    gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

    # Apply inverted binary thresholding:
    # pixels above threshold_value become black (0), the rest become white (255),
    # so dark text on a light background ends up white
    _, thresh = cv2.threshold(gray, threshold_value, 255, cv2.THRESH_BINARY_INV)

    # Create a rectangular kernel for morphological operations
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, kernel_height))

    # Morphological closing: fills gaps in text lines,
    # connecting nearby characters into continuous lines
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)

    # Find contours (connected components) in the thresholded image
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort contours top-to-bottom so lines come out in reading order
    contours = sorted(contours, key=lambda c: cv2.boundingRect(c)[1])

    # Extract bounding boxes for each contour and crop lines
    lines = []
    for cnt in contours:
        if cv2.contourArea(cnt) > min_area:  # Filter out noise (small contours)
            x, y, w, h = cv2.boundingRect(cnt)
            lines.append(image.crop((x, y, x + w, y + h)))
    return lines
```

How it works:
- Grayscale Conversion: Converts the color image to grayscale for simpler processing
- Thresholding: Creates a binary (black/white) image where text is white and background is black
- Morphological Closing: Uses a horizontal kernel to connect characters that might be slightly separated, forming continuous lines
- Contour Detection: Finds all connected white regions (potential text lines)
- Sorting: Orders contours top-to-bottom so lines are extracted in reading order
- Filtering: Removes small contours that are likely noise
- Cropping: Extracts each line as a separate image
Parameters Explained:
- `threshold_value` (default: 150): Determines the brightness cutoff for text vs. background
- `kernel_width` (default: 20): Width of the morphological kernel; wider kernels connect more distant characters
- `kernel_height` (default: 1): Height of the kernel; kept small to only connect horizontally
- `min_area` (default: 50): Minimum pixel area for a contour to be considered a valid text line
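As a usage sketch, assuming the `segment_lines()` function above is in scope and `page.png` stands in for a rendered PDF page (the tuned values are illustrative assumptions, not project recommendations):

```python
from PIL import Image

page_image = Image.open("page.png")  # placeholder: one rendered PDF page

# The defaults work well for typical printed pages
lines = segment_lines(page_image)

# For dense, small print: a wider kernel bridges larger gaps between words,
# and a smaller min_area keeps short lines that would otherwise be dropped
dense_lines = segment_lines(page_image, kernel_width=35, min_area=25)

print(f"default: {len(lines)} lines, dense settings: {len(dense_lines)} lines")
```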
3. Text Extraction with TrOCR
The main text extraction function processes PDF files page by page:
```python
import io

import fitz  # PyMuPDF
import torch
from PIL import Image

def extract_text(file_stream, model_name=None, threshold_value=150,
                 kernel_width=20, kernel_height=1, min_area=50):
    """Extracts text from a PDF file using OCR."""
    global processor, model

    # Load a different model if specified
    if model_name and model_name != default_model_name:
        try:
            processor = TrOCRProcessor.from_pretrained(model_name)
            model = VisionEncoderDecoderModel.from_pretrained(model_name)
        except Exception as e:
            raise Exception(f"Error loading model: {e}")

    # Open the PDF from the uploaded byte stream
    doc = fitz.open(stream=file_stream.read(), filetype="pdf")
    text = ""

    # Use GPU if available, otherwise CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Process each page
    for page in doc:
        # Render the page as an image (pixmap) and convert to PIL
        pix = page.get_pixmap()
        image = Image.open(io.BytesIO(pix.tobytes()))

        # Segment into lines
        lines = segment_lines(image, threshold_value, kernel_width,
                              kernel_height, min_area)

        # Extract text from each line
        for line in lines:
            # Preprocess the image for the model
            inputs = processor(images=line, return_tensors="pt").to(device)
            # Generate text tokens with the model
            outputs = model.generate(**inputs)
            # Decode the output tokens to text
            text += processor.batch_decode(outputs, skip_special_tokens=True)[0] + "\n"

    return text
```

Processing Flow:
- Model Loading: If a different model is specified, it is loaded on demand (loading is expensive, so the new model replaces the global one and is reused by subsequent requests)
- PDF Parsing: Uses PyMuPDF (fitz) to open the PDF from a byte stream
- Device Selection: Automatically uses CUDA GPU if available, otherwise falls back to CPU
- Page Rendering: Each PDF page is rendered as a pixmap (raster image)
- Line Segmentation: The page image is segmented into individual text lines
- OCR Processing: Each line is processed through TrOCR:
  - The processor normalizes and prepares the image
  - The model generates text tokens
  - Tokens are decoded back to readable text
- Text Aggregation: All lines are combined with newlines between them
TrOCR Model Architecture:
- Encoder: Vision Transformer (ViT) that processes the image
- Decoder: Transformer decoder that generates text tokens
- The model was pre-trained on millions of text images and fine-tuned for OCR
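To see this step in isolation, here is a self-contained sketch that runs TrOCR on a single cropped line image (`line.png` is a placeholder):

```python
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Any RGB image containing a single line of text
line_image = Image.open("line.png").convert("RGB")

# The processor resizes and normalizes the image into pixel_values for the ViT encoder
inputs = processor(images=line_image, return_tensors="pt").to(device)

# The decoder autoregressively generates text token IDs
generated_ids = model.generate(**inputs)

# Decode the token IDs back into a string
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```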
4. Flask API Endpoints
The application exposes two main endpoints:
OCR Endpoint
```python
@app.route('/ocr', methods=['POST'])
def ocr():
    """Handles OCR processing for uploaded files."""
    # Get the file from form data
    file = request.files['file']

    # Get optional parameters with defaults
    model_name = request.form.get('model', default_model_name)
    threshold_value = int(request.form.get('threshold_value', 150))
    kernel_width = int(request.form.get('kernel_width', 20))
    kernel_height = int(request.form.get('kernel_height', 1))
    min_area = int(request.form.get('min_area', 50))

    try:
        text = extract_text(
            file,
            model_name=model_name,
            threshold_value=threshold_value,
            kernel_width=kernel_width,
            kernel_height=kernel_height,
            min_area=min_area
        )
        return jsonify({'text': text})
    except Exception as e:
        return internal_server_error(e)
```

Request Format:
- Method: POST
- Content-Type: multipart/form-data
- Required: `file` (the PDF file)
- Optional parameters:
  - `model`: TrOCR model name (default: "microsoft/trocr-base-printed")
  - `threshold_value`: Image thresholding value (default: 150)
  - `kernel_width`: Morphological kernel width (default: 20)
  - `kernel_height`: Morphological kernel height (default: 1)
  - `min_area`: Minimum contour area (default: 50)
Response Format:
```json
{
  "text": "Extracted text from the PDF...\n"
}
```

Models Endpoint
```python
@app.route('/models', methods=['GET'])
def list_models():
    """Lists all supported OCR models."""
    models = get_supported_models()
    return jsonify({'supported_models': models})

def get_supported_models():
    """Returns the list of supported TrOCR models."""
    return [
        "microsoft/trocr-large-handwritten",
        "microsoft/trocr-large-printed",
        "microsoft/trocr-small-printed",
        "microsoft/trocr-small-handwritten",
        "microsoft/trocr-base-handwritten",
        "microsoft/trocr-base-printed",
        "microsoft/trocr-base-stage1",
        "microsoft/trocr-large-stage1"
    ]
```
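A quick client-side check of this endpoint, assuming the service is running locally on port 5000:

```python
import requests

response = requests.get('http://localhost:5000/models')
print(response.json()['supported_models'])
# ['microsoft/trocr-large-handwritten', 'microsoft/trocr-large-printed', ...]
```

5. Error Handling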
The application includes comprehensive error handling:
```python
@app.errorhandler(400)
def bad_request(error):
    return jsonify({'error': 'Bad request', 'details': str(error)}), 400

@app.errorhandler(500)
def internal_server_error(error):
    return jsonify({'error': 'Internal server error', 'details': str(error)}), 500
```
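One gap worth noting: the `/ocr` endpoint reads `request.files['file']` directly, so a missing field raises a `KeyError`. A hedged sketch of explicit validation (this wiring is an assumption, not code from the project):

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/ocr', methods=['POST'])
def ocr():
    # Validate up front instead of letting a KeyError surface as a vague error
    if 'file' not in request.files:
        return jsonify({'error': 'Bad request',
                        'details': "missing 'file' field in form data"}), 400
    uploaded = request.files['file']
    if not uploaded.filename.lower().endswith('.pdf'):
        return jsonify({'error': 'Bad request',
                        'details': 'only PDF files are supported'}), 400
    # ... hand the file off to extract_text() as shown earlier
    return jsonify({'text': ''})
```

Docker Deployment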
The project includes a multi-stage Dockerfile for efficient deployment:
```dockerfile
# First stage: build and install dependencies
FROM python:3.12-slim AS builder
WORKDIR /app
ENV PATH="/app/venv/bin:$PATH"
RUN apt-get update && apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/* \
    && python -m venv venv
COPY src/*.txt ./
RUN pip install --no-cache-dir --timeout=120 -r requirements.txt

# Second stage: create the final image
FROM python:3.12-slim
WORKDIR /app
COPY --from=builder /app/venv /app/venv
ENV PATH="/app/venv/bin:$PATH"
COPY src/ ./
EXPOSE 5000
CMD ["gunicorn", "app:app", "-w", "4", "-b", "0.0.0.0:5000"]
```

Key Features:
- Multi-stage build: Reduces final image size by excluding build tools
- Virtual environment: Isolates dependencies
- Gunicorn: Production WSGI server with 4 worker processes (note that each worker loads its own copy of the model, so memory usage scales with the worker count)
- Optimized layers: Dependencies are installed in a separate layer for better caching
Usage Example
Here's how to use the API:
```bash
# Using curl
curl -X POST http://localhost:5000/ocr \
  -F "file=@document.pdf" \
  -F "model=microsoft/trocr-base-printed" \
  -F "threshold_value=150"
```

```python
# Using Python requests
import requests

with open('document.pdf', 'rb') as f:
    response = requests.post(
        'http://localhost:5000/ocr',
        files={'file': f},
        data={
            'model': 'microsoft/trocr-base-printed',
            'threshold_value': 150,
            'kernel_width': 20,
            'min_area': 50
        }
    )

result = response.json()
print(result['text'])
```

Model Selection Guide
Different TrOCR models are optimized for different use cases:
- `trocr-base-printed`: Good balance of speed and accuracy for printed text
- `trocr-large-printed`: Higher accuracy for printed text, slower
- `trocr-base-handwritten`: Optimized for handwritten text
- `trocr-large-handwritten`: Best accuracy for handwritten text
- `trocr-small-*`: Faster inference, lower accuracy
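If you want to select a model programmatically, a small hypothetical helper (the checkpoint names mirror the supported list above; the selection logic itself is an illustrative assumption):

```python
def pick_model(handwritten: bool = False, size: str = "base") -> str:
    """Hypothetical helper: map document type and a size/speed preference
    to one of the supported TrOCR checkpoints."""
    if size not in ("small", "base", "large"):
        raise ValueError("size must be 'small', 'base', or 'large'")
    style = "handwritten" if handwritten else "printed"
    return f"microsoft/trocr-{size}-{style}"

# e.g. a fast first pass over printed scans:
print(pick_model(size="small"))  # microsoft/trocr-small-printed
```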
Performance Considerations
- Model Loading: Models are loaded once at startup, not per request
- GPU Acceleration: Automatically uses CUDA if available
- Batch Processing: Consider batching multiple lines for better GPU utilization (see the sketch after this list)
- Caching: Processed results could be cached for identical documents
- Async Processing: For large PDFs, consider using background tasks (Celery, RQ)
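For the batching point above, a hedged sketch of how the per-line loop could submit lines in batches rather than one at a time (`processor` and `model` as loaded earlier; the batch size is an arbitrary assumption):

```python
import torch

def ocr_lines_batched(lines, processor, model, device, batch_size=8):
    """Run TrOCR over cropped line images in batches instead of one at a time."""
    texts = []
    for i in range(0, len(lines), batch_size):
        batch = lines[i:i + batch_size]
        # The processor resizes every image to the same shape,
        # so a list of differently sized crops stacks into one tensor
        inputs = processor(images=batch, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        texts.extend(processor.batch_decode(outputs, skip_special_tokens=True))
    return texts
```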
Dependencies Breakdown
- Flask: Web framework for the API
- PyMuPDF (fitz): PDF parsing and rendering
- transformers: Hugging Face library for TrOCR models
- torch: PyTorch for deep learning inference
- opencv-python-headless: Image processing (headless = no GUI dependencies)
- Pillow: Image manipulation
- sentencepiece: Text tokenization for the model
- gunicorn: Production WSGI server
Conclusion
This PDF OCR service demonstrates how to combine traditional computer vision techniques (image segmentation) with modern transformer-based AI models to create a practical OCR solution. The architecture is modular, allowing easy customization of image processing parameters and model selection based on your specific use case.
The key takeaway is the two-stage approach: first segmenting the image into lines using OpenCV, then processing each line with TrOCR. This approach works well because:
- It handles multi-line text naturally
- It's more accurate than feeding entire pages to the model, since TrOCR is trained on single text-line images
- It allows fine-tuning of image preprocessing parameters
For production use, consider adding:
- Request rate limiting
- Authentication/authorization
- Result caching
- Async processing for large files
- Monitoring and logging
- Health check endpoints (a minimal sketch follows)
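As one concrete example of that last item, a minimal health check sketch, assuming `app` and `default_model_name` from the setup section are in scope (the fields reported are an assumption about what is useful to expose):

```python
import torch
from flask import jsonify

@app.route('/health', methods=['GET'])
def health():
    """Liveness/readiness probe for load balancers and orchestrators."""
    return jsonify({
        'status': 'ok',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'default_model': default_model_name
    })
```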