TJKlein committed on
Commit befb6c9 · 1 Parent(s): 385926e

Update README.md

Files changed (1): README.md (+78 -2)
README.md CHANGED
@@ -13,7 +13,12 @@ The **miCSE** language model is trained for sentence similarity computation. Tra
The model is intended to be used for encoding sentences or short paragraphs. Given an input text, the model produces a vector embedding that captures the semantics. The embedding can be used for numerous tasks, e.g., **retrieval**, **clustering**, or **sentence similarity** comparison (see examples below). Sentence representations correspond to the embedding of the _**[CLS]**_ token.

+ # Training data
+
+ The model was trained on a random collection of **English** sentences from Wikipedia: [Training data file](https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt)
+
# Model Usage
+ ## Example 1) - Sentence Similarity

```python
from transformers import AutoTokenizer, AutoModel
@@ -61,9 +66,80 @@ print(f"Distance: {cos_sim[0,1].detach().item()}")

```

- # Training data
+ ## Example 2) - Clustering

- The model was trained on a random collection of **English** sentences from Wikipedia: [Training data file](https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt)
+ ```python
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ import numpy as np
+ from tqdm.auto import tqdm
+ from datasets import load_dataset
+ import umap
+ import umap.plot as umap_plot
+
+ # Determine available hardware
+ if torch.backends.mps.is_available():
+     device = torch.device("mps")
+ elif torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+ # Load tokenizer and model; move the model to the selected device
+ tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")
+ model = AutoModel.from_pretrained("sap-ai-research/miCSE").to(device)
+ model.eval()
+
+ # Load Twitter data for sentiment clustering
+ dataset = load_dataset("tweet_eval", "sentiment")
+
+ # Compute embeddings of the tweets
+
+ # Set batch size and maximum tweet token length
+ batch_size = 50
+ max_length = 128
+
+ # Number of tweets to process, truncated to a multiple of the batch size
+ num_samples = (len(dataset['train']) // batch_size) * batch_size
+
+ embedding_stack = []
+ classes = []
+ for i in tqdm(range(0, num_samples, batch_size)):
+     # Create a tokenized batch and move it to the device
+     batch = tokenizer.batch_encode_plus(
+         dataset['train'][i:i+batch_size]['text'],
+         return_tensors='pt',
+         padding=True,
+         max_length=max_length,
+         truncation=True
+     ).to(device)
+     classes = classes + dataset['train'][i:i+batch_size]['label']
+
+     # Model inference without gradient computation
+     with torch.no_grad():
+         outputs = model(**batch, output_hidden_states=True, return_dict=True)
+
+     # Sentence embedding = the [CLS] token of the last hidden state
+     embeddings = outputs.last_hidden_state[:, 0]
+     embedding_stack.append(embeddings.cpu().clone())
+
+ embeddings = torch.vstack(embedding_stack)
+
+ # Cluster embeddings in 2D with UMAP (cosine metric matches how the embeddings are compared)
+ umap_model = umap.UMAP(n_neighbors=250,
+                        n_components=2,
+                        min_dist=1.0e-9,
+                        low_memory=True,
+                        angular_rp_forest=True,
+                        metric='cosine')
+ umap_model.fit(embeddings.numpy())
+
+ # Plot result
+ umap_plot.points(umap_model, labels=np.array(classes), theme='fire')
+ ```
 
  # Benchmark
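
The diff elides the body of Example 1 (the README lines between the two hunks), showing only its opening import and the closing `print` of `cos_sim[0,1]`. For orientation, here is a minimal sketch of a [CLS]-based sentence-similarity computation with the standard `transformers`/`torch` APIs; the Hub ID `sap-ai-research/miCSE` and all names except `cos_sim` (chosen to match the visible context line) are assumptions, not a copy of the elided code.

```python
from transformers import AutoTokenizer, AutoModel
import torch

# Illustrative sketch only; Hub ID assumed, variable names hypothetical
tokenizer = AutoTokenizer.from_pretrained("sap-ai-research/miCSE")
model = AutoModel.from_pretrained("sap-ai-research/miCSE")
model.eval()

sentences = [
    "A kid is playing in the garden.",
    "A child is playing outside.",
]
batch = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    outputs = model(**batch, return_dict=True)

# Sentence embeddings = the [CLS] token of the last hidden state
embeddings = outputs.last_hidden_state[:, 0]

# Pairwise cosine similarity; cos_sim[0,1] compares the two sentences
cos_sim = torch.nn.functional.cosine_similarity(
    embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=-1
)
print(f"Distance: {cos_sim[0,1].detach().item()}")
```

Cosine similarity is the natural comparison here: the clustering example added in hunk 2 likewise fits UMAP with `metric='cosine'` on the same [CLS] embeddings.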
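
The training-data link added in hunk 1 points at a plain-text file with one sentence per line (the SimCSE wiki1m corpus). If you want to inspect it, a sketch using the generic `text` loader from `datasets`; the loader choice is an assumption, since the README does not show how the file was consumed during training.

```python
from datasets import load_dataset

# Assumption: plain text, one sentence per line, loaded straight from the URL in the README
wiki = load_dataset(
    "text",
    data_files="https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt",
    split="train",
)
print(len(wiki), wiki[0]["text"])  # corpus size and the first sentence
```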