metadata
library_name: transformers
license: apache-2.0
datasets:
- StanfordAIMI/interpret-cxr-test-public
- StanfordAIMI/interpret-cxr-test-hidden
CXRMate-RRG24: Entropy-Augmented Self-Critical Sequence Training for Radiology Report Generation
This model is an evolution of CXRMate (https://huggingface.co/aehrc/cxrmate), developed for the Radiology Report Generation task of BioNLP @ ACL 2024.
For this, we proposed Entropy-Augmented Self-critical sequence Training (EAST):
- EAST modifies Self-Critical Sequence Training (SCST) by adding an entropy regularisation term.
- This helps maintain a higher entropy in the token distribution.
- In turn, this prevents overfitting to common phrases and encourages broader exploration of the vocabulary during training.
- This was essential for handling the diversity of the radiology reports in the RRG24 datasets.
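The loss below is a minimal sketch of how entropy regularisation can be combined with SCST. It is not the exact EAST objective from the paper. It assumes the standard SCST form: the reward of a sampled report minus the reward of the greedy-decoded report (the baseline), multiplied by the sampled sequence's log-probability. It adds a mean per-step entropy bonus weighted by a hypothetical `entropy_coef`. The helper names are illustrative. The code only computes the loss value in pure Python; a real training loop would compute the same quantities from model logits with an autograd framework such as PyTorch.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def entropy(log_probs: List[float]) -> float:
    """Shannon entropy H(p) = -sum_v p(v) log p(v) of one token distribution."""
    return -sum(math.exp(lp) * lp for lp in log_probs)


def east_loss(
    step_logits: List[List[float]],  # logits at each step of the sampled report
    sampled_ids: List[int],          # token sampled at each step
    sample_reward: float,            # e.g. RadGraph reward of the sampled report
    baseline_reward: float,          # reward of the greedy-decoded report (SCST baseline)
    entropy_coef: float = 0.01,      # hypothetical weight of the entropy bonus
) -> float:
    log_probs = [log_softmax(logits) for logits in step_logits]
    # SCST: REINFORCE with the greedy decode's reward as the baseline.
    advantage = sample_reward - baseline_reward
    seq_log_prob = sum(lp[t] for lp, t in zip(log_probs, sampled_ids))
    scst = -advantage * seq_log_prob
    # Entropy regularisation: subtracting the mean per-step entropy favours
    # flatter token distributions, discouraging collapse onto common phrases.
    mean_entropy = sum(entropy(lp) for lp in log_probs) / len(log_probs)
    return scst - entropy_coef * mean_entropy
```

With `entropy_coef = 0` this reduces to plain SCST; increasing it trades some reward for flatter token distributions.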
EAST was applied to a multimodal language model with RadGraph as the reward. Other features include:
- Token type embeddings to differentiate between findings section tokens, impression section tokens, and image embeddings.
- Special tokens ([NF] and [NI]) to handle missing findings and impression sections.
- Non-causal attention masking for the image embeddings and causal attention masking for the report token embeddings.
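The sketch below illustrates how the special tokens, token types and the mixed attention pattern described above could fit together. The token type ids, the image-first layout, and the choice that image embeddings do not attend to report tokens are assumptions for illustration, not details taken from the model's code; `build_report` and `attention_mask` are hypothetical helpers.

```python
from typing import List, Optional, Tuple

# Hypothetical token type ids: one per segment of the decoder input.
IMAGE, FINDINGS, IMPRESSION = 0, 1, 2


def build_report(
    findings: Optional[List[str]], impression: Optional[List[str]]
) -> Tuple[List[str], List[int]]:
    """Concatenate both sections, using [NF]/[NI] when a section is missing."""
    f = findings if findings else ["[NF]"]
    i = impression if impression else ["[NI]"]
    token_types = [FINDINGS] * len(f) + [IMPRESSION] * len(i)
    return f + i, token_types


def attention_mask(num_image: int, num_report: int) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k.

    Assumed layout: image embeddings first, then report tokens.
    """
    n = num_image + num_report
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            if k < num_image:
                # Non-causal: every position sees every image embedding.
                mask[q][k] = True
            elif q >= num_image and k <= q:
                # Causal: a report token sees itself and earlier report tokens.
                mask[q][k] = True
            # Otherwise False: image embeddings do not see report tokens (assumption).
    return mask


tokens, types = build_report(None, ["no", "acute", "findings"])
image_types = [IMAGE] * 2  # two image embeddings in this toy example
mask = attention_mask(num_image=len(image_types), num_report=len(tokens))
```

Here the missing findings section becomes a single `[NF]` token, so the model still sees an explicit findings segment it can learn to predict as "absent".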
How to use:
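import torch
from PIL import Image
from torchvision.transforms import v2
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('aehrc/cxrmate-rrg24')
model = transformers.AutoModel.from_pretrained('aehrc/cxrmate-rrg24', trust_remote_code=True)

# Preprocessing: 3-channel grayscale, resize/crop to the encoder's input size, scale to [0, 1], normalise.
transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(size=model.config.encoder.image_size, antialias=True),
        v2.CenterCrop(size=[model.config.encoder.image_size]*2),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=model.config.encoder.image_mean, std=model.config.encoder.image_std),
    ]
)

image = Image.open('path/to/chest_x-ray.jpg')  # Replace with the path to your own image.
# Assumed layout: (batch_size, images_per_study, channels, height, width); see the notebook for multi-image studies.
images = transforms(image).unsqueeze(0).unsqueeze(0)

output_ids = model.generate(
    pixel_values=images,
    max_length=512,
    # Ban the missing-section tokens so that both sections are generated.
    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
    num_beams=4,
    use_cache=True,
)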
findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
Notebook example:
https://huggingface.co/aehrc/cxrmate-rrg24/blob/main/demo.ipynb
Paper:
Citation:
[More Information Needed]