anicolson committed · Commit e8acc55 · verified · 1 Parent(s): 95af3f2

Update README.md

Files changed (1)
  1. README.md +13 -4
README.md CHANGED
@@ -21,7 +21,7 @@ EAST was applied to a multimodal language model with RadGraph as the reward. Oth
 - Special tokens (`NF` and `NI`) to handle missing *findings* and *impression* sections.
 - Non-causal attention masking for the image embeddings and a causal attention masking for the report token embeddings.
 
-## How to use:
+## Example:
 
 ```python
 import torch
@@ -42,14 +42,23 @@ transforms = v2.Compose(
     ]
 )
 
-image = transforms(image) # Fix.
+dataset = datasets.load_dataset('StanfordAIMI/interpret-cxr-test-public')['test']
+
+def transform_batch(batch):
+    batch['images'] = [torch.stack([transforms(j) for j in i]) for i in batch['images']]
+    batch['images'] = torch.nn.utils.rnn.pad_sequence(batch['images'], batch_first=True, padding_value=0.0)
+    return batch
+
+dataset = dataset.with_transform(transform_batch)
+dataloader = DataLoader(dataset, batch_size=mbatch_size, shuffle=True)
+batch = next(iter(dataloader))
 
 output_ids = model.generate(
-    pixel_values=images, # Fix.
+    pixel_values=batch['images'],
     max_length=512,
-    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
     num_beams=4,
     use_cache=True,
+    bad_words_ids=[[tokenizer.convert_tokens_to_ids('[NF]')], [tokenizer.convert_tokens_to_ids('[NI]')]],
 )
 findings, impression = model.split_and_decode_sections(output_ids, tokenizer)
 ```
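
The README context above mentions non-causal attention masking for the image embeddings and causal masking for the report token embeddings. The model's own masking code is not part of this diff, so the following is only a minimal pure-Python sketch of one way such a mask could look. It assumes the image embeddings come first in the sequence and that image positions do not attend to report tokens; `build_prefix_mask` is a hypothetical helper, not a function of this repository.

```python
from typing import List


def build_prefix_mask(num_image: int, num_text: int) -> List[List[int]]:
    """Hypothetical helper: mask[q][k] == 1 means query position q may attend to key k.

    Assumption: positions [0, num_image) are image embeddings and the
    remaining positions are report tokens.
    """
    total = num_image + num_text
    mask = [[0] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            if k < num_image:
                # Image keys are visible to every query: image-to-image attention
                # is bidirectional, and report tokens can see all image embeddings.
                mask[q][k] = 1
            elif q >= num_image and k <= q:
                # Report-token keys are visible only to report queries at or
                # after them (causal).
                mask[q][k] = 1
    return mask


# Three image embeddings followed by two report tokens.
for row in build_prefix_mask(num_image=3, num_text=2):
    print(row)
# [1, 1, 1, 0, 0]
# [1, 1, 1, 0, 0]
# [1, 1, 1, 0, 0]
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]
```

The top-left block is fully visible (non-causal over images), while the bottom-right block is lower-triangular (causal over report tokens).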