davidheineman committed
Commit bccc6f8 · 1 Parent(s): 72131b0

remove all files but the indices
.gitattributes DELETED
@@ -1,39 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *.bib filter=lfs diff=lfs merge=lfs -text
37
- collection.json filter=lfs diff=lfs merge=lfs -text
38
- dataset.json filter=lfs diff=lfs merge=lfs -text
39
- index/metadata.json filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1,4 +0,0 @@
1
- __pycache__
2
- experiments
3
- .DS_Store
4
- app
Dockerfile DELETED
@@ -1,18 +0,0 @@
1
- FROM python:3.10
2
-
3
- WORKDIR /app
4
-
5
- COPY requirements.txt .
6
- RUN pip install --no-cache-dir -r requirements.txt
7
-
8
- COPY . .
9
-
10
- # Copy ColBERT files that aren't downloaded properly
11
- COPY ./src/extras/segmented_maxsim.cpp /usr/local/lib/python3.10/site-packages/colbert/modeling/segmented_maxsim.cpp
12
- COPY ./src/extras/decompress_residuals.cpp /usr/local/lib/python3.10/site-packages/colbert/search/decompress_residuals.cpp
13
- COPY ./src/extras/filter_pids.cpp /usr/local/lib/python3.10/site-packages/colbert/search/filter_pids.cpp
14
- COPY ./src/extras/segmented_lookup.cpp /usr/local/lib/python3.10/site-packages/colbert/search/segmented_lookup.cpp
15
-
16
- # CMD ["sh", "-c", "sleep infinity"]
17
- CMD ["python", "src/server.py"]
18
- # CMD ["gunicorn", "-w", "4", "-b", "0.0.0.0:8893", "src/server:app"]
README.md CHANGED
@@ -2,110 +2,6 @@
2
  license: apache-2.0
3
  ---
4
 
5
- Use ColBERT as a search engine for the [ACL Anthology](https://aclanthology.org/). (Parse any bibtex, and store in a MySQL service)
6
 
7
- # Setup
8
-
9
- ## Setup ColBERT
10
- ```sh
11
- git clone https://huggingface.co/davidheineman/colbert-acl
12
-
13
- # install dependencies
14
- # torch==1.13.1 required (conda install -y -n [env] python=3.10)
15
- pip install -r requirements.txt
16
- brew install mysql
17
- ```
18
-
19
- ### (Optional) Parse & Index the Anthology
20
-
21
- Feel free to skip, since the parsed/indexed anthology is contained in this repo.
22
-
23
- ```sh
24
- # get up-to-date abstracts in bibtex
25
- curl -O https://aclanthology.org/anthology+abstracts.bib.gz
26
- gunzip anthology+abstracts.bib.gz
27
- mv anthology+abstracts.bib anthology.bib
28
-
29
- # parse .bib -> .json
30
- python parse.py
31
-
32
- # index with ColBERT
33
- # (note sometimes there is a silent failure if the CPP extensions do not exist)
34
- python index.py
35
- ```
36
-
37
- ### Search with ColBERT
38
-
39
- ```sh
40
- # start flask server
41
- python server.py
42
-
43
- # or start a production API endpoint
44
- gunicorn -w 4 -b 0.0.0.0:8893 server:app
45
- ```
46
-
47
- Then, to test, visit:
48
- ```
49
- http://localhost:8893/api/search?query=Information retrieval with BERT
50
- ```
51
- or for an interface:
52
- ```
53
- http://localhost:8893
54
- ```
55
-
56
- ### Deploy as a Docker App
57
- ```sh
58
- docker-compose build --no-cache
59
- docker-compose up --build
60
- ```
61
-
62
- ## Example notebooks
63
-
64
- To see an example of search, visit:
65
- [colab.research.google.com/drive/1-b90_8YSAK17KQ6C7nqKRYbCWEXQ9FGs](https://colab.research.google.com/drive/1-b90_8YSAK17KQ6C7nqKRYbCWEXQ9FGs?usp=sharing)
66
-
67
- <!-- ## Notes
68
- - See:
69
- - https://github.com/stanford-futuredata/ColBERT/blob/main/colbert/index_updater.py
70
- - https://github.com/stanford-futuredata/ColBERT/issues/111
71
-
72
- - TODO:
73
- - On UI
74
- - Colors: make the colors resemble the ACL page much closer
75
- There's still a bunch of blue from the bootstrap theming
76
- - Smaller line spacing for abstract text
77
- - Add "PDF" button
78
- - Justify the result metadata (Year, venue, etc.) so the content all starts at the same vertical position
79
- - Add a "Expand" button at the end of the abstract
80
- - Make the results scrollable, without scrolling the rest of the page
81
- - Put two sliders on the year range (and make the years selectable, with the years at both ends of the bar)
82
- - If the user selects certain venues, remember these venues
83
- - Add a dropdown under the "Workshop" box to select specific workshops
84
-
85
- - Include the title in the indexing
86
-
87
- - https://docs.docker.com/language/python/configure-ci-cd/
88
-
89
- - Have articles before 2020
90
-
91
- - Maybe make the UI more compressed like this: https://aclanthology.org/events/eacl-2024/#2024eacl-long
92
-
93
- - Put query in URL (?q=XXX)
94
-
95
- - Move code to github and index to hf, then use this to download the index:
96
- from huggingface_hub import snapshot_download
97
-
98
- # Download indexed repo at: https://huggingface.co/davidheineman/colbert-acl
99
- !mkdir "acl"
100
- index_name = snapshot_download(repo_id="davidheineman/colbert-acl", local_dir="acl")
101
- - Make indexing much easier
102
- (currently, the setup involves manually copying the CPP files because there is a silent failure; this should also be possible on Google Colab, or even MPS)
103
- - Make index save in parent folder
104
- - Fix "sanity check" in index.py
105
- - Profile bibtexparser.load(f) (why so slow)
106
- - Ship as a containerized service
107
- - Scrape:
108
- - https://proceedings.neurips.cc/
109
- - https://dblp.uni-trier.de/db/conf/iclr/index.html
110
- - openreview
111
- -->
 
2
  license: apache-2.0
3
  ---
4
 
5
+ Use ColBERT as a search engine for the [ACL Anthology](https://aclanthology.org/). (Parse any bibtex, and store in a MySQL service).
6
 
7
+ **This repo contains the ColBERT index and dataset.** To run the interface, see [github.com/davidheineman/acl-search](https://github.com/davidheineman/acl-search).
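
Since this repo now ships only the index and dataset, here is a minimal sketch of pulling them locally, adapted from the `snapshot_download` snippet in the notes removed above (the `acl` directory name is just an illustrative choice):

```python
# Minimal sketch: download the ColBERT index + dataset stored in this repo.
# Adapted from the snapshot_download snippet in the removed notes above.
from huggingface_hub import snapshot_download

# Fetches index/, dataset.json, collection.json, etc. into ./acl
local_path = snapshot_download(repo_id="davidheineman/colbert-acl", local_dir="acl")
print(local_path)
```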
docker-compose.yml DELETED
@@ -1,33 +0,0 @@
1
- services:
2
- mysql:
3
- image: mysql:8.0
4
- container_name: mysql_db
5
- environment:
6
- MYSQL_ROOT_PASSWORD:
7
- MYSQL_ALLOW_EMPTY_PASSWORD: true
8
- MYSQL_DATABASE: anthology
9
- MYSQL_USER: myuser
10
- MYSQL_PASSWORD: mysecret
11
- volumes:
12
- - mysql_data:/var/lib/mysql
13
- networks:
14
- - mysql_network
15
-
16
- python:
17
- build:
18
- context: .
19
- dockerfile: Dockerfile
20
- container_name: python_app
21
- ports:
22
- - "8893:8893" # host:local
23
- depends_on:
24
- - mysql
25
- networks:
26
- - mysql_network
27
-
28
- networks:
29
- mysql_network:
30
- driver: bridge
31
-
32
- volumes:
33
- mysql_data:
requirements.txt DELETED
@@ -1,7 +0,0 @@
1
- torch==1.13.1
2
- colbert-ir[torch,faiss-cpu] # faiss-gpu
3
- faiss-cpu # shouldn't have to include
4
- bibtexparser
5
- mysql-connector-python
6
- flask
7
- gunicorn
src/constants.py DELETED
@@ -1,15 +0,0 @@
1
- import os
2
- from typing import Literal
3
-
4
- INDEX_NAME = os.getenv("INDEX_NAME", 'index')
5
- INDEX_ROOT = os.getenv("INDEX_ROOT", os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
-
7
- INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
8
- ANTHOLOGY_PATH = os.path.join(INDEX_ROOT, 'anthology.bib')
9
- DATASET_PATH = os.path.join(INDEX_ROOT, 'dataset.json')
10
-
11
- DB_NAME = 'anthology'
12
- DB_HOSTNAME = 'mysql_db' # localhost
13
- DB_PORT = 3306 # None
14
-
15
- VENUES = Literal['workshop', 'journal', 'short', 'demo', 'tutorial', 'industry', 'findings', 'main']
src/db.py DELETED
@@ -1,166 +0,0 @@
1
- import json
2
- from typing import List, Optional, Union
3
-
4
- import mysql.connector
5
- from constants import DATASET_PATH, DB_HOSTNAME, DB_NAME, DB_PORT, VENUES
6
-
7
- PAPER_QUERY = """
8
- SELECT *
9
- FROM paper
10
- WHERE pid IN ({query_arg_str}){constraints_str};
11
- """
12
-
13
- def read_dataset():
14
- print("Reading dataset...")
15
- with open(DATASET_PATH, 'r', encoding='utf-8') as f:
16
- dataset = json.loads(f.read())
17
- dataset = [d for d in dataset if 'abstract' in d.keys()]
18
- return dataset
19
-
20
-
21
- def create_database():
22
- db = mysql.connector.connect(
23
- host = DB_HOSTNAME,
24
- user = "root",
25
- password = "",
26
- port = DB_PORT
27
- )
28
- cursor = db.cursor()
29
-
30
- cursor.execute("SHOW DATABASES")
31
- db_exists = False
32
- for x in cursor:
33
- db_name = x[0]
34
- if db_name == DB_NAME:
35
- db_exists = True
36
-
37
- # Create database
38
- if not db_exists:
39
- print("Creating new database...")
40
- cursor.execute(f'CREATE DATABASE {DB_NAME}')
41
- cursor.execute(f'USE {DB_NAME}')
42
-
43
- # Create table
44
- print('Creating new table...')
45
- cursor.execute(f'DROP TABLE IF EXISTS paper')
46
- cursor.execute("""
47
- CREATE TABLE paper (
48
- pid INT PRIMARY KEY,
49
- title VARCHAR(1024),
50
- author VARCHAR(2170),
51
- year INT,
52
- abstract TEXT(12800),
53
- url VARCHAR(150),
54
- type VARCHAR(100),
55
- venue VARCHAR(500),
56
- venue_type VARCHAR(150),
57
- is_findings TINYINT(1) NOT NULL DEFAULT 0
58
- )
59
- """)
60
-
61
- acl_data = read_dataset()
62
-
63
- vals = []
64
- for pid, paper in enumerate(acl_data):
65
- title = paper.get('title', '')
66
- author = paper.get('author', '')
67
- year = paper.get('year', '')
68
- abstract = paper.get('abstract', '')
69
- url = paper.get('url', '')
70
- type = paper.get('ENTRYTYPE', '')
71
- venue = paper.get('booktitle', '')
72
- venue_type = paper.get('venue_type', '')
73
- is_findings = paper.get('is_findings', '0')
74
-
75
- if not abstract: continue
76
-
77
- vals += [(pid, title, author, year, abstract, url, type, venue, venue_type, is_findings)]
78
-
79
- sql = """
80
- INSERT INTO paper (
81
- pid, title, author, year, abstract, url, type, venue, venue_type, is_findings
82
- ) VALUES (
83
- %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
84
- )
85
- """
86
-
87
- print('Writing entries to table...')
88
- cursor.executemany(sql, vals)
89
- db.commit()
90
-
91
-
92
- def parse_results(results):
93
- parsed_results = {}
94
-
95
- for result in results:
96
- pid, title, authors, year, abstract, url, type, venue, venue_type, is_findings = result
97
-
98
- title = title.replace("{", "").replace("}", "")
99
- authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
100
- abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
101
-
102
- parsed_results[int(pid)] = {
103
- 'title': title,
104
- 'authors': authors,
105
- 'year': year,
106
- 'abstract': abstract,
107
- 'url': url,
108
- 'type': type,
109
- 'venue': venue,
110
- 'venue_type': venue_type,
111
- 'is_findings': is_findings,
112
- }
113
-
114
- return parsed_results
115
-
116
-
117
- def query_paper_metadata(
118
- pids: List[int],
119
- start_year: int = None,
120
- end_year: int = None,
121
- venue_type: Union[VENUES, List[VENUES]] = None,
122
- is_findings: Optional[bool] = None
123
- ):
124
- if not isinstance(venue_type, list): venue_type = [venue_type]
125
-
126
- db = mysql.connector.connect(
127
- host = DB_HOSTNAME,
128
- user = "root",
129
- password = "",
130
- database = DB_NAME,
131
- port = DB_PORT
132
- )
133
-
134
- cursor = db.cursor()
135
-
136
- pids_str = ', '.join(['%s'] * len(pids))
137
-
138
- constraints_str = ""
139
- if start_year: constraints_str += f" AND year >= {start_year}"
140
- if end_year: constraints_str += f" AND year <= {end_year}"
141
- if is_findings: constraints_str += f" AND is_findings = {is_findings}"
142
- if venue_type:
143
- venue_str = ','.join([f'"{venue}"' for venue in venue_type])
144
- constraints_str += f" AND venue_type IN ({venue_str})"
145
-
146
- query = PAPER_QUERY.format(
147
- query_arg_str=pids_str,
148
- constraints_str=constraints_str
149
- )
150
-
151
- # print(PAPER_QUERY.format(query_arg_str=', '.join([str(p) for p in pids]), year=year))
152
-
153
- cursor.execute(query, pids)
154
- results = cursor.fetchall()
155
-
156
- if len(results) == 0: return []
157
-
158
- parsed_results = parse_results(results)
159
-
160
- # Restore original ordering of PIDs from ColBERT
161
- results = [parsed_results[pid] for pid in pids if pid in parsed_results.keys()]
162
-
163
- return results
164
-
165
-
166
- if __name__ == '__main__': create_database()
src/extras/decompress_residuals.cpp DELETED
@@ -1,160 +0,0 @@
1
- #include <pthread.h>
2
- #include <torch/extension.h>
3
-
4
- typedef struct decompress_args {
5
- int tid;
6
- int nthreads;
7
-
8
- int npids;
9
- int dim;
10
- int packed_dim;
11
- int npacked_vals_per_byte;
12
-
13
- int* pids;
14
- int64_t* lengths;
15
- int64_t* offsets;
16
- float* bucket_weights;
17
- uint8_t* reversed_bit_map;
18
- uint8_t* bucket_weight_combinations;
19
- uint8_t* binary_residuals;
20
- int* codes;
21
- float* centroids;
22
- int64_t* cumulative_lengths;
23
-
24
- float* output;
25
- } decompress_args_t;
26
-
27
- void* decompress(void* args) {
28
- decompress_args_t* decompress_args = (decompress_args_t*)args;
29
-
30
- int npids_per_thread = (int)std::ceil(((float)decompress_args->npids) /
31
- decompress_args->nthreads);
32
- int start = decompress_args->tid * npids_per_thread;
33
- int end = std::min((decompress_args->tid + 1) * npids_per_thread,
34
- decompress_args->npids);
35
-
36
- // Iterate over all documents
37
- for (int i = start; i < end; i++) {
38
- int pid = decompress_args->pids[i];
39
-
40
- // Offset into packed list of token vectors for the given document
41
- int64_t offset = decompress_args->offsets[pid];
42
-
43
- // For each document, iterate over all token vectors
44
- for (int j = 0; j < decompress_args->lengths[pid]; j++) {
45
- const int code = decompress_args->codes[offset + j];
46
-
47
- // For each token vector, iterate over the packed (8-bit) residual
48
- // values
49
- for (int k = 0; k < decompress_args->packed_dim; k++) {
50
- uint8_t x =
51
- decompress_args->binary_residuals
52
- [(offset + j) * decompress_args->packed_dim + k];
53
- x = decompress_args->reversed_bit_map[x];
54
-
55
- // For each packed residual value, iterate over the bucket
56
- // weight indices. If we use n-bit compression, that means there
57
- // will be (8 / n) indices per packed value.
58
- for (int l = 0; l < decompress_args->npacked_vals_per_byte;
59
- l++) {
60
- const int output_dim_idx =
61
- k * decompress_args->npacked_vals_per_byte + l;
62
- const int bucket_weight_idx =
63
- decompress_args->bucket_weight_combinations
64
- [x * decompress_args->npacked_vals_per_byte + l];
65
- decompress_args
66
- ->output[(decompress_args->cumulative_lengths[i] + j) *
67
- decompress_args->dim +
68
- output_dim_idx] =
69
- decompress_args->bucket_weights[bucket_weight_idx] +
70
- decompress_args->centroids[code * decompress_args->dim +
71
- output_dim_idx];
72
- }
73
- }
74
- }
75
- }
76
-
77
- return NULL;
78
- }
79
-
80
- torch::Tensor decompress_residuals(
81
- const torch::Tensor pids, const torch::Tensor lengths,
82
- const torch::Tensor offsets, const torch::Tensor bucket_weights,
83
- const torch::Tensor reversed_bit_map,
84
- const torch::Tensor bucket_weight_combinations,
85
- const torch::Tensor binary_residuals, const torch::Tensor codes,
86
- const torch::Tensor centroids, const int dim, const int nbits) {
87
- const int npacked_vals_per_byte = (8 / nbits);
88
- const int packed_dim = (int)(dim / npacked_vals_per_byte);
89
-
90
- int npids = pids.size(0);
91
- int* pids_a = pids.data_ptr<int>();
92
- int64_t* lengths_a = lengths.data_ptr<int64_t>();
93
- int64_t* offsets_a = offsets.data_ptr<int64_t>();
94
- float* bucket_weights_a = bucket_weights.data_ptr<float>();
95
- uint8_t* reversed_bit_map_a = reversed_bit_map.data_ptr<uint8_t>();
96
- uint8_t* bucket_weight_combinations_a =
97
- bucket_weight_combinations.data_ptr<uint8_t>();
98
- uint8_t* binary_residuals_a = binary_residuals.data_ptr<uint8_t>();
99
- int* codes_a = codes.data_ptr<int>();
100
- float* centroids_a = centroids.data_ptr<float>();
101
-
102
- int64_t cumulative_lengths[npids + 1];
103
- int noutputs = 0;
104
- cumulative_lengths[0] = 0;
105
- for (int i = 0; i < npids; i++) {
106
- noutputs += lengths_a[pids_a[i]];
107
- cumulative_lengths[i + 1] =
108
- cumulative_lengths[i] + lengths_a[pids_a[i]];
109
- }
110
-
111
- auto options =
112
- torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
113
- torch::Tensor output = torch::zeros({noutputs, dim}, options);
114
- float* output_a = output.data_ptr<float>();
115
-
116
- auto nthreads = at::get_num_threads();
117
-
118
- pthread_t threads[nthreads];
119
- decompress_args_t args[nthreads];
120
-
121
- for (int i = 0; i < nthreads; i++) {
122
- args[i].tid = i;
123
- args[i].nthreads = nthreads;
124
-
125
- args[i].npids = npids;
126
- args[i].dim = dim;
127
- args[i].packed_dim = packed_dim;
128
- args[i].npacked_vals_per_byte = npacked_vals_per_byte;
129
-
130
- args[i].pids = pids_a;
131
- args[i].lengths = lengths_a;
132
- args[i].offsets = offsets_a;
133
- args[i].bucket_weights = bucket_weights_a;
134
- args[i].reversed_bit_map = reversed_bit_map_a;
135
- args[i].bucket_weight_combinations = bucket_weight_combinations_a;
136
- args[i].binary_residuals = binary_residuals_a;
137
- args[i].codes = codes_a;
138
- args[i].centroids = centroids_a;
139
- args[i].cumulative_lengths = cumulative_lengths;
140
-
141
- args[i].output = output_a;
142
-
143
- int rc = pthread_create(&threads[i], NULL, decompress, (void*)&args[i]);
144
- if (rc) {
145
- fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
146
- std::exit(1);
147
- }
148
- }
149
-
150
- for (int i = 0; i < nthreads; i++) {
151
- pthread_join(threads[i], NULL);
152
- }
153
-
154
- return output;
155
- }
156
-
157
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
158
- m.def("decompress_residuals_cpp", &decompress_residuals,
159
- "Decompress residuals", py::call_guard<py::gil_scoped_release>());
160
- }
src/extras/filter_pids.cpp DELETED
@@ -1,174 +0,0 @@
1
- #include <pthread.h>
2
- #include <torch/extension.h>
3
-
4
- #include <algorithm>
5
- #include <chrono>
6
- #include <numeric>
7
- #include <utility>
8
-
9
- typedef struct maxsim_args {
10
- int tid;
11
- int nthreads;
12
-
13
- int ncentroids;
14
- int nquery_vectors;
15
- int npids;
16
-
17
- int* pids;
18
- float* centroid_scores;
19
- int* codes;
20
- int64_t* doclens;
21
- int64_t* offsets;
22
- bool* idx;
23
-
24
- std::priority_queue<std::pair<float, int>> approx_scores;
25
- } maxsim_args_t;
26
-
27
- void* maxsim(void* args) {
28
- maxsim_args_t* maxsim_args = (maxsim_args_t*)args;
29
-
30
- float per_doc_approx_scores[maxsim_args->nquery_vectors];
31
- for (int k = 0; k < maxsim_args->nquery_vectors; k++) {
32
- per_doc_approx_scores[k] = -9999;
33
- }
34
-
35
- int ndocs_per_thread =
36
- (int)std::ceil(((float)maxsim_args->npids) / maxsim_args->nthreads);
37
- int start = maxsim_args->tid * ndocs_per_thread;
38
- int end =
39
- std::min((maxsim_args->tid + 1) * ndocs_per_thread, maxsim_args->npids);
40
-
41
- std::unordered_set<int> seen_codes;
42
-
43
- for (int i = start; i < end; i++) {
44
- auto pid = maxsim_args->pids[i];
45
- for (int j = 0; j < maxsim_args->doclens[pid]; j++) {
46
- auto code = maxsim_args->codes[maxsim_args->offsets[pid] + j];
47
- assert(code < maxsim_args->ncentroids);
48
- if (maxsim_args->idx[code] &&
49
- seen_codes.find(code) == seen_codes.end()) {
50
- for (int k = 0; k < maxsim_args->nquery_vectors; k++) {
51
- per_doc_approx_scores[k] =
52
- std::max(per_doc_approx_scores[k],
53
- maxsim_args->centroid_scores
54
- [code * maxsim_args->nquery_vectors + k]);
55
- }
56
- seen_codes.insert(code);
57
- }
58
- }
59
- float score = 0;
60
- for (int k = 0; k < maxsim_args->nquery_vectors; k++) {
61
- score += per_doc_approx_scores[k];
62
- per_doc_approx_scores[k] = -9999;
63
- }
64
- maxsim_args->approx_scores.push(std::make_pair(score, pid));
65
- seen_codes.clear();
66
- }
67
-
68
- return NULL;
69
- }
70
-
71
- std::vector<int> filter_pids_helper(int ncentroids, int nquery_vectors, int npids,
72
- int* pids, float* centroid_scores, int* codes,
73
- int64_t* doclens, int64_t* offsets, bool* idx,
74
- int nfiltered_docs) {
75
- auto nthreads = at::get_num_threads();
76
-
77
- pthread_t threads[nthreads];
78
- maxsim_args_t args[nthreads];
79
-
80
- for (int i = 0; i < nthreads; i++) {
81
- args[i].tid = i;
82
- args[i].nthreads = nthreads;
83
-
84
- args[i].ncentroids = ncentroids;
85
- args[i].nquery_vectors = nquery_vectors;
86
- args[i].npids = npids;
87
-
88
- args[i].pids = pids;
89
- args[i].centroid_scores = centroid_scores;
90
- args[i].codes = codes;
91
- args[i].doclens = doclens;
92
- args[i].offsets = offsets;
93
- args[i].idx = idx;
94
-
95
- args[i].approx_scores = std::priority_queue<std::pair<float, int>>();
96
-
97
- int rc = pthread_create(&threads[i], NULL, maxsim, (void*)&args[i]);
98
- if (rc) {
99
- fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
100
- std::exit(1);
101
- }
102
- }
103
-
104
- for (int i = 0; i < nthreads; i++) {
105
- pthread_join(threads[i], NULL);
106
- }
107
-
108
- std::priority_queue<std::pair<float, int>> global_approx_scores;
109
- for (int i = 0; i < nthreads; i++) {
110
- for (int j = 0; j < nfiltered_docs; j++) {
111
- if (args[i].approx_scores.empty()) {
112
- break;
113
- }
114
- global_approx_scores.push(args[i].approx_scores.top());
115
- args[i].approx_scores.pop();
116
- }
117
- }
118
-
119
- std::vector<int> filtered_pids;
120
- for (int i = 0; i < nfiltered_docs; i++) {
121
- if (global_approx_scores.empty()) {
122
- break;
123
- }
124
- std::pair<float, int> score_and_pid = global_approx_scores.top();
125
- global_approx_scores.pop();
126
- filtered_pids.push_back(score_and_pid.second);
127
- }
128
-
129
- return filtered_pids;
130
- }
131
-
132
- torch::Tensor filter_pids(const torch::Tensor pids,
133
- const torch::Tensor centroid_scores,
134
- const torch::Tensor codes,
135
- const torch::Tensor doclens,
136
- const torch::Tensor offsets, const torch::Tensor idx,
137
- int nfiltered_docs) {
138
- auto ncentroids = centroid_scores.size(0);
139
- auto nquery_vectors = centroid_scores.size(1);
140
- auto npids = pids.size(0);
141
-
142
- auto pids_a = pids.data_ptr<int>();
143
- auto centroid_scores_a = centroid_scores.data_ptr<float>();
144
- auto codes_a = codes.data_ptr<int>();
145
- auto doclens_a = doclens.data_ptr<int64_t>();
146
- auto offsets_a = offsets.data_ptr<int64_t>();
147
- auto idx_a = idx.data_ptr<bool>();
148
-
149
- std::vector<int> filtered_pids = filter_pids_helper(ncentroids, nquery_vectors, npids, pids_a,
150
- centroid_scores_a, codes_a, doclens_a, offsets_a, idx_a,
151
- nfiltered_docs);
152
-
153
- int nfinal_filtered_docs = (int)(nfiltered_docs / 4);
154
- bool ones[ncentroids];
155
- for (int i = 0; i < ncentroids; i++) {
156
- ones[i] = true;
157
- }
158
-
159
- int* filtered_pids_a = filtered_pids.data();
160
- auto nfiltered_pids = filtered_pids.size();
161
- std::vector<int> final_filtered_pids = filter_pids_helper(ncentroids, nquery_vectors, nfiltered_pids,
162
- filtered_pids_a, centroid_scores_a, codes_a, doclens_a,
163
- offsets_a, ones, nfinal_filtered_docs);
164
-
165
- auto options =
166
- torch::TensorOptions().dtype(torch::kInt32).requires_grad(false);
167
- return torch::from_blob(final_filtered_pids.data(), {(int)final_filtered_pids.size()},
168
- options)
169
- .clone();
170
- }
171
-
172
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
173
- m.def("filter_pids_cpp", &filter_pids, "Filter pids", py::call_guard<py::gil_scoped_release>());
174
- }
src/extras/segmented_lookup.cpp DELETED
@@ -1,148 +0,0 @@
1
- #include <pthread.h>
2
- #include <torch/extension.h>
3
-
4
- #include <algorithm>
5
- #include <numeric>
6
-
7
- typedef struct {
8
- int tid;
9
- pthread_mutex_t* mutex;
10
- std::queue<int>* queue;
11
-
12
- int64_t ndocs;
13
- int64_t noutputs;
14
- int64_t dim;
15
-
16
- void* input;
17
- int64_t* lengths;
18
- int64_t* offsets;
19
- int64_t* cumulative_lengths;
20
-
21
- void* output;
22
- } lookup_args_t;
23
-
24
- template <typename T>
25
- void* lookup(void* args) {
26
- lookup_args_t* lookup_args = (lookup_args_t*)args;
27
-
28
- int64_t* lengths = lookup_args->lengths;
29
- int64_t* cumulative_lengths = lookup_args->cumulative_lengths;
30
- int64_t* offsets = lookup_args->offsets;
31
- int64_t dim = lookup_args->dim;
32
-
33
- T* input = static_cast<T*>(lookup_args->input);
34
- T* output = static_cast<T*>(lookup_args->output);
35
-
36
- while (1) {
37
- pthread_mutex_lock(lookup_args->mutex);
38
- if (lookup_args->queue->empty()) {
39
- pthread_mutex_unlock(lookup_args->mutex);
40
- return NULL;
41
- }
42
- int i = lookup_args->queue->front();
43
- lookup_args->queue->pop();
44
- pthread_mutex_unlock(lookup_args->mutex);
45
-
46
- std::memcpy(output + (cumulative_lengths[i] * dim),
47
- input + (offsets[i] * dim), lengths[i] * dim * sizeof(T));
48
- }
49
- }
50
-
51
- template <typename T>
52
- torch::Tensor segmented_lookup_impl(const torch::Tensor input,
53
- const torch::Tensor pids,
54
- const torch::Tensor lengths,
55
- const torch::Tensor offsets) {
56
- auto lengths_a = lengths.data_ptr<int64_t>();
57
- auto offsets_a = offsets.data_ptr<int64_t>();
58
-
59
- int64_t ndocs = pids.size(0);
60
- int64_t noutputs = std::accumulate(lengths_a, lengths_a + ndocs, 0);
61
-
62
- int nthreads = at::get_num_threads();
63
-
64
- int64_t dim;
65
- torch::Tensor output;
66
-
67
- if (input.dim() == 1) {
68
- dim = 1;
69
- output = torch::zeros({noutputs}, input.options());
70
- } else {
71
- assert(input.dim() == 2);
72
- dim = input.size(1);
73
- output = torch::zeros({noutputs, dim}, input.options());
74
- }
75
-
76
- int64_t cumulative_lengths[ndocs + 1];
77
- cumulative_lengths[0] = 0;
78
- std::partial_sum(lengths_a, lengths_a + ndocs, cumulative_lengths + 1);
79
-
80
- pthread_mutex_t mutex;
81
- int rc = pthread_mutex_init(&mutex, NULL);
82
- if (rc) {
83
- fprintf(stderr, "Unable to init mutex: %d\n", rc);
84
- }
85
-
86
- std::queue<int> queue;
87
- for (int i = 0; i < ndocs; i++) {
88
- queue.push(i);
89
- }
90
-
91
- pthread_t threads[nthreads];
92
- lookup_args_t args[nthreads];
93
- for (int i = 0; i < nthreads; i++) {
94
- args[i].tid = i;
95
- args[i].mutex = &mutex;
96
- args[i].queue = &queue;
97
-
98
- args[i].ndocs = ndocs;
99
- args[i].noutputs = noutputs;
100
- args[i].dim = dim;
101
-
102
- args[i].input = (void*)input.data_ptr<T>();
103
- args[i].lengths = lengths_a;
104
- args[i].offsets = offsets_a;
105
- args[i].cumulative_lengths = cumulative_lengths;
106
-
107
- args[i].output = (void*)output.data_ptr<T>();
108
-
109
- rc = pthread_create(&threads[i], NULL, lookup<T>, (void*)&args[i]);
110
- if (rc) {
111
- fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
112
- }
113
- }
114
-
115
- for (int i = 0; i < nthreads; i++) {
116
- pthread_join(threads[i], NULL);
117
- }
118
-
119
- rc = pthread_mutex_destroy(&mutex);
120
- if (rc) {
121
- fprintf(stderr, "Unable to destroy mutex: %d\n", rc);
122
- }
123
-
124
- return output;
125
- }
126
-
127
- torch::Tensor segmented_lookup(const torch::Tensor input,
128
- const torch::Tensor pids,
129
- const torch::Tensor lengths,
130
- const torch::Tensor offsets) {
131
- if (input.dtype() == torch::kUInt8) {
132
- return segmented_lookup_impl<uint8_t>(input, pids, lengths, offsets);
133
- } else if (input.dtype() == torch::kInt32) {
134
- return segmented_lookup_impl<int>(input, pids, lengths, offsets);
135
- } else if (input.dtype() == torch::kInt64) {
136
- return segmented_lookup_impl<int64_t>(input, pids, lengths, offsets);
137
- } else if (input.dtype() == torch::kFloat32) {
138
- return segmented_lookup_impl<float>(input, pids, lengths, offsets);
139
- } else if (input.dtype() == torch::kFloat16) {
140
- return segmented_lookup_impl<at::Half>(input, pids, lengths, offsets);
141
- } else {
142
- assert(false);
143
- }
144
- }
145
-
146
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
147
- m.def("segmented_lookup_cpp", &segmented_lookup, "Segmented lookup");
148
- }
src/extras/segmented_maxsim.cpp DELETED
@@ -1,97 +0,0 @@
1
- #include <pthread.h>
2
- #include <torch/extension.h>
3
-
4
- #include <algorithm>
5
- #include <numeric>
6
-
7
- typedef struct {
8
- int tid;
9
- int nthreads;
10
-
11
- int ndocs;
12
- int ndoc_vectors;
13
- int nquery_vectors;
14
-
15
- int64_t* lengths;
16
- float* scores;
17
- int64_t* offsets;
18
-
19
- float* max_scores;
20
- } max_args_t;
21
-
22
- void* max(void* args) {
23
- max_args_t* max_args = (max_args_t*)args;
24
-
25
- int ndocs_per_thread =
26
- std::ceil(((float)max_args->ndocs) / max_args->nthreads);
27
- int start = max_args->tid * ndocs_per_thread;
28
- int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
29
-
30
- auto max_scores_offset =
31
- max_args->max_scores + (start * max_args->nquery_vectors);
32
- auto scores_offset =
33
- max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
34
-
35
- for (int i = start; i < end; i++) {
36
- for (int j = 0; j < max_args->lengths[i]; j++) {
37
- std::transform(max_scores_offset,
38
- max_scores_offset + max_args->nquery_vectors,
39
- scores_offset, max_scores_offset,
40
- [](float a, float b) { return std::max(a, b); });
41
- scores_offset += max_args->nquery_vectors;
42
- }
43
- max_scores_offset += max_args->nquery_vectors;
44
- }
45
-
46
- return NULL;
47
- }
48
-
49
- torch::Tensor segmented_maxsim(const torch::Tensor scores,
50
- const torch::Tensor lengths) {
51
- auto lengths_a = lengths.data_ptr<int64_t>();
52
- auto scores_a = scores.data_ptr<float>();
53
- auto ndocs = lengths.size(0);
54
- auto ndoc_vectors = scores.size(0);
55
- auto nquery_vectors = scores.size(1);
56
- auto nthreads = at::get_num_threads();
57
-
58
- torch::Tensor max_scores =
59
- torch::zeros({ndocs, nquery_vectors}, scores.options());
60
-
61
- int64_t offsets[ndocs + 1];
62
- offsets[0] = 0;
63
- std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
64
-
65
- pthread_t threads[nthreads];
66
- max_args_t args[nthreads];
67
-
68
- for (int i = 0; i < nthreads; i++) {
69
- args[i].tid = i;
70
- args[i].nthreads = nthreads;
71
-
72
- args[i].ndocs = ndocs;
73
- args[i].ndoc_vectors = ndoc_vectors;
74
- args[i].nquery_vectors = nquery_vectors;
75
-
76
- args[i].lengths = lengths_a;
77
- args[i].scores = scores_a;
78
- args[i].offsets = offsets;
79
-
80
- args[i].max_scores = max_scores.data_ptr<float>();
81
-
82
- int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
83
- if (rc) {
84
- fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
85
- }
86
- }
87
-
88
- for (int i = 0; i < nthreads; i++) {
89
- pthread_join(threads[i], NULL);
90
- }
91
-
92
- return max_scores.sum(1);
93
- }
94
-
95
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
96
- m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
97
- }
src/index.py DELETED
@@ -1,67 +0,0 @@
1
- import os
2
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevents deadlocks in ColBERT tokenization
3
- os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Allows multiple libraries in OpenMP runtime. This can cause unexected behavior, but allows ColBERT to work
4
-
5
- import json
6
-
7
- from constants import INDEX_NAME, DATASET_PATH
8
-
9
- from colbert import Indexer, Searcher
10
- from colbert.infra import Run, RunConfig, ColBERTConfig
11
-
12
-
13
- nbits = 2 # encode each dimension with 2 bits
14
- doc_maxlen = 512 # truncate passages
15
- checkpoint = 'colbert-ir/colbertv2.0' # ColBERT model to use
16
-
17
-
18
- def index_anthology(collection, index_name):
19
- with Run().context(RunConfig(nranks=2, experiment='notebook')): # nranks specifies the number of GPUs to use
20
- config = ColBERTConfig(
21
- doc_maxlen=doc_maxlen,
22
- nbits=nbits,
23
- kmeans_niters=4, # specifies the number of iterations of k-means clustering; 4 is a good and fast default.
24
- index_path=INDEX_NAME,
25
- bsize=1024
26
- )
27
-
28
- indexer = Indexer(
29
- checkpoint=checkpoint,
30
- config=config
31
- )
32
-
33
- indexer.index(
34
- name=index_name,
35
- collection=collection,
36
- overwrite=True
37
- )
38
-
39
-
40
- def search_anthology(query, collection, index_name):
41
- """ Default ColBERT search function """
42
- with Run().context(RunConfig(nranks=0, experiment='notebook')):
43
- searcher = Searcher(index=index_name, collection=collection)
44
-
45
- results = searcher.search(query, k=3)
46
-
47
- for passage_id, passage_rank, passage_score in zip(*results):
48
- print(f"\t [{passage_rank}] \t\t {passage_score:.1f} \t\t {searcher.collection[passage_id]}")
49
-
50
-
51
- def main():
52
- # Load the parsed anthology
53
- with open(DATASET_PATH, 'r', encoding='utf-8') as f:
54
- dataset = json.loads(f.read())
55
-
56
- # Get the abstracts for indexing
57
- collection = [e['abstract'] for e in dataset]
58
-
59
- # Run ColBERT indexer
60
- index_anthology(collection, index_name=INDEX_NAME)
61
-
62
- # Sanity check
63
- # query = ["What are some recent examples of grammar checkers?"]
64
- # search_anthology(query, collection, index_name=INDEX_NAME)
65
-
66
-
67
- if __name__ == '__main__': main()
src/parse.py DELETED
@@ -1,113 +0,0 @@
1
- import bibtexparser, json
2
-
3
- from constants import ANTHOLOGY_PATH, DATASET_PATH
4
-
5
-
6
- def parse_bibtex(anthology_path, dataset_path):
7
- with open(anthology_path, 'r', encoding='utf-8') as f:
8
- bib = bibtexparser.load(f)
9
- dataset = bib.entries
10
-
11
- print(f'Found {len(dataset)} articles with keys: {dataset[0].keys()}')
12
- paper: dict
13
- for paper in dataset[:2]:
14
- print(f"{paper.get('author')}\n{paper.get('title')}\n{paper.get('url')}\n")
15
-
16
- # Remove any entries without abstracts, since we index on abstracts
17
- dataset = [paper for paper in dataset if 'abstract' in paper.keys()]
18
-
19
- with open(dataset_path, 'w', encoding='utf-8') as f:
20
- f.write(json.dumps(dataset, indent=4))
21
-
22
- return dataset
23
-
24
-
25
- def preprocess_acl_entries(dataset_path):
26
- """
27
- Very rough attempt at using ACL URLs to infer their venues. Bless this mess.
28
- """
29
- with open(dataset_path, 'r', encoding='utf-8') as f:
30
- dataset = json.loads(f.read())
31
-
32
- venues = []
33
- for id, paper in enumerate(dataset):
34
- url = paper['url']
35
- year = int(paper['year'])
36
-
37
- if year < 2020:
38
- dataset[id]['findings'] = None
39
- dataset[id]['venue_type'] = None
40
- continue
41
-
42
- if 'https://aclanthology.org/' in url:
43
- url = url.split('https://aclanthology.org/')[1]
44
- elif 'http://www.lrec-conf.org/proceedings/' in url:
45
- url = url.split('http://www.lrec-conf.org/proceedings/')[1]
46
-
47
- if year >= 2020:
48
- # new URL format
49
-
50
- url_new = '.'.join(url.split('.')[:-1])
51
- if url_new != '': url = url_new
52
-
53
- # For most new venues, the format is "2023.eacl-tutorials" -> "eacl-tutorials"
54
- url_new = '.'.join(url.split('.')[1:])
55
- if url_new != '': url = url_new
56
-
57
- # 'acl-main' -> 'acl-long'?
58
- # 'acl-main' -> 'acl-short'?
59
-
60
- # 'eacl-demo' -> 'eacl-demos'
61
- # 'emnlp-tutorial' -> 'emnlp-tutorials'
62
- url = url.replace('-demos', '-demo')
63
- url = url.replace('-tutorials', '-tutorial')
64
-
65
- elif year >= 2016:
66
- # old URL format
67
- # P17-1001 -> P17
68
-
69
- url = url.split('-')[0]
70
-
71
- raise RuntimeError('not working')
72
-
73
- venues += [url]
74
-
75
- # Extract paper type from URL
76
- _type = None
77
- if any(venue in url for venue in ['parlaclarin', 'nlpcovid19']):
78
- _type = 'workshop'
79
- elif not any(venue in url for venue in ['aacl', 'naacl', 'acl', 'emnlp', 'eacl', 'tacl']):
80
- _type = 'workshop'
81
- elif 'tacl' in url: _type = 'journal'
82
- elif 'srw' in url: _type = 'workshop'
83
- elif 'short' in url: _type = 'short'
84
- elif 'demo' in url: _type = 'demo'
85
- elif 'tutorial' in url: _type = 'tutorial'
86
- elif 'industry' in url: _type = 'industry'
87
- elif 'findings' in url: _type = 'findings'
88
- elif 'main' in url or 'long' in url: _type = 'main'
89
- else:
90
- print(f'Could not parse: {url}')
91
-
92
- findings = ('findings' in url)
93
-
94
- dataset[id]['findings'] = findings
95
- dataset[id]['venue_type'] = _type
96
-
97
- # print(set(venues))
98
-
99
- with open(DATASET_PATH, 'w', encoding='utf-8') as f:
100
- f.write(json.dumps(dataset, indent=4))
101
-
102
- return dataset
103
-
104
-
105
- def main():
106
- # 1) Parse and save the anthology dataset
107
- dataset = parse_bibtex(ANTHOLOGY_PATH, DATASET_PATH)
108
-
109
- # 2) Pre-process the ACL anthology
110
- dataset = preprocess_acl_entries(DATASET_PATH)
111
-
112
-
113
- if __name__ == '__main__': main()
src/search.py DELETED
@@ -1,204 +0,0 @@
1
- import os, ujson, tqdm
2
- import torch
3
- import torch.nn.functional as F
4
-
5
- from colbert import Checkpoint
6
- from colbert.infra.config import ColBERTConfig
7
- from colbert.search.index_storage import IndexScorer
8
- from colbert.search.strided_tensor import StridedTensor
9
- from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
10
- from colbert.indexing.codecs.residual import ResidualCodec
11
-
12
-
13
- NCELLS = 1 # Number of centroids to use in PLAID
14
- CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
15
- NDOCS = 512 # Number of closest documents to consider
16
-
17
-
18
- def init_colbert(index_path, load_index_with_mmap=False):
19
- """
20
- Load all tensors necessary for running ColBERT
21
- """
22
- global index_checkpoint, scorer, centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
23
-
24
- # index_checkpoint: Checkpoint
25
-
26
- use_gpu = torch.cuda.is_available()
27
- if use_gpu:
28
- device = 'cuda'
29
- else:
30
- device = 'cpu'
31
-
32
- # Load index checkpoint
33
- from colbert.infra.run import Run
34
- initial_config = ColBERTConfig.from_existing(None, Run().config)
35
- index_config = ColBERTConfig.load_from_index(index_path)
36
- checkpoint_path = index_config.checkpoint
37
- checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint_path)
38
- config: ColBERTConfig = ColBERTConfig.from_existing(checkpoint_config, index_config, initial_config)
39
-
40
- index_checkpoint = Checkpoint(checkpoint_path, colbert_config=config)
41
- index_checkpoint = index_checkpoint.to(device)
42
-
43
- load_index_with_mmap = config.load_index_with_mmap
44
- if load_index_with_mmap and use_gpu:
45
- raise ValueError(f"Memory-mapped index can only be used with CPU!")
46
-
47
- scorer = IndexScorer(index_path, use_gpu, load_index_with_mmap)
48
-
49
- with open(os.path.join(index_path, 'metadata.json')) as f:
50
- metadata = ujson.load(f)
51
- nbits = metadata['config']['nbits']
52
-
53
- centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location=device)
54
- centroids = centroids.float()
55
-
56
- ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location=device)
57
- ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)
58
-
59
- embeddings = ResidualCodec.Embeddings.load_chunks(
60
- index_path,
61
- range(metadata['num_chunks']),
62
- metadata['num_embeddings'],
63
- load_index_with_mmap=load_index_with_mmap,
64
- )
65
-
66
- doclens = []
67
- for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
68
- with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
69
- chunk_doclens = ujson.load(f)
70
- doclens.extend(chunk_doclens)
71
- doclens = torch.tensor(doclens)
72
-
73
- buckets_path = os.path.join(index_path, 'buckets.pt')
74
- bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location=device)
75
- bucket_weights = bucket_weights.float()
76
-
77
- codec = ResidualCodec.load(index_path)
78
-
79
- if load_index_with_mmap:
80
- assert metadata['num_chunks'] == 1
81
- offsets = torch.cumsum(doclens, dim=0)
82
- offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
83
- else:
84
- embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
85
- offsets = embeddings_strided.codes_strided.offsets
86
-
87
-
88
- def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
89
- """
90
- Computes late interaction between question (Q) and documents (D)
91
- See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
92
- """
93
- assert Q.dim() == 3, Q.size()
94
- assert D_padded.dim() == 3, D_padded.size()
95
- assert Q.size(0) in [1, D_padded.size(0)]
96
-
97
- scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
98
-
99
- D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
100
- scores_padded[D_padding] = -9999
101
- scores = scores_padded.max(1).values
102
- scores = scores.sum(-1)
103
-
104
- return scores
105
-
106
-
107
- def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
108
- """
109
- First find centroids closest to Q, then return all the passages in all
110
- centroids.
111
-
112
- We can replace this function with a k-NN search finding the closest passages
113
- using BERT similarity.
114
- """
115
- Q = Q.squeeze(0)
116
-
117
- # Get the closest centroids via a matrix multiplication + argmax
118
- centroid_scores: torch.Tensor = (centroids @ Q.T)
119
- if NCELLS == 1:
120
- cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
121
- else:
122
- cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells)
123
- cells = cells.flatten().contiguous() # (32 * ncells,)
124
- cells = cells.unique(sorted=False)
125
-
126
- # Given the relevant clusters, get all passage IDs in each cluster
127
- # Note, this may return duplicates since passages can exist in multiple clusters
128
- pids, _ = ivf.lookup(cells)
129
-
130
- # Sort and return values
131
- pids = pids.sort().values
132
- pids, _ = torch.unique_consecutive(pids, return_counts=True)
133
- return pids, centroid_scores
134
-
135
-
136
- def _calculate_colbert(Q: torch.Tensor):
137
- """
138
- Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
139
- https://arxiv.org/pdf/2205.09707#page=5
140
- """
141
- # Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
142
- unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
143
- print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
144
-
145
- # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
146
- idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
147
-
148
- # C++ : Filter pids under the centroid score threshold
149
- pids_true = scorer.filter_pids(
150
- unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
151
- )
152
- pids = pids_true
153
- assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
154
- print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
155
-
156
- # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
157
- D_packed = scorer.decompress_residuals(
158
- pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
159
- codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
160
- centroids, codec.dim, nbits
161
- )
162
- D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
163
- D_mask = doclens[pids.long()]
164
- D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
165
- print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape) # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)
166
-
167
- # Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT
168
- scores = colbert_score(Q, D_padded, D_lengths)
169
- print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)
170
-
171
- return scores, pids
172
-
173
-
174
- def encode(text, full_length_search=False) -> torch.Tensor:
175
- queries = text if isinstance(text, list) else [text]
176
- bsize = 128 if len(queries) > 128 else None
177
-
178
- Q = index_checkpoint.queryFromText(
179
- queries,
180
- bsize=bsize,
181
- to_cpu=True,
182
- full_length_search=full_length_search
183
- )
184
-
185
- QUERY_MAX_LEN = index_checkpoint.query_tokenizer.query_maxlen
186
- Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
187
-
188
- return Q
189
-
190
-
191
- def search_colbert(query):
192
- """
193
- ColBERT search with a query.
194
- """
195
- # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
196
- Q = encode(query)
197
-
198
- scores, pids = _calculate_colbert(Q)
199
-
200
- # Sort values
201
- scores_sorter = scores.sort(descending=True)
202
- pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()
203
-
204
- return pids, scores
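
For reference, the removed `colbert_score` implements the standard MaxSim late-interaction score; a minimal sketch with hypothetical tensor shapes (and without the padding mask the removed code handles) looks like:

```python
# Sketch of the MaxSim late-interaction score from the removed colbert_score()
# (padding mask omitted; shapes are hypothetical).
import torch

Q = torch.randn(1, 32, 128)        # (1, n_query_tokens, dim)
D = torch.randn(4, 180, 128)       # (n_docs, n_doc_tokens, dim)

sim = D @ Q.permute(0, 2, 1)       # (n_docs, n_doc_tokens, n_query_tokens)
maxsim = sim.max(dim=1).values     # best-matching doc token per query token
scores = maxsim.sum(dim=-1)        # (n_docs,) one relevance score per document
```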
src/server.py DELETED
@@ -1,106 +0,0 @@
1
- import os, math, re
2
- from typing import List, Optional, Union
3
-
4
- from flask import Flask, abort, request, render_template
5
- from functools import lru_cache
6
-
7
- from constants import INDEX_PATH, VENUES
8
-
9
- from search import init_colbert, search_colbert
10
- from db import create_database, query_paper_metadata
11
-
12
- PORT = int(os.getenv("PORT", 8893))
13
- app = Flask(__name__)
14
-
15
-
16
- @lru_cache(maxsize=1000000)
17
- def api_search_query(query):
18
- print(f"Query={query}")
19
-
20
- # Use ColBERT to find passages related to the query
21
- pids, scores = search_colbert(query)
22
-
23
- # Softmax output probs
24
- probs = [math.exp(s) for s in scores]
25
- probs = [p / sum(probs) for p in probs]
26
-
27
- # Sort and return results as a dict
28
- topk = [{'pid': pid, 'score': score, 'prob': prob} for pid, score, prob in zip(pids, scores, probs)]
29
- topk = sorted(topk, key=lambda p: (p['score'], p['pid']), reverse=True)
30
-
31
- response = {"query" : query, "topk": topk}
32
-
33
- return response
34
-
35
-
36
- def is_valid_query(query):
37
- return re.match(r'^[a-zA-Z0-9 ]*$', query) and len(query) <= 256
38
-
39
-
40
- @app.route("/api/colbert", methods=["GET"])
41
- def api_search():
42
- if request.method == "GET":
43
- query = str(request.args.get('query'))
44
- if not is_valid_query(query): abort(400, "Invalid query :(")
45
- return api_search_query(query)
46
- return ('', 405)
47
-
48
-
49
- @app.route('/api/search', methods=['POST', 'GET'])
50
- def query():
51
- query: str
52
- start_year: Optional[int]
53
- end_year: Optional[int]
54
- venue_type: Optional[Union[VENUES, List[VENUES]]]
55
- is_findings: Optional[bool]
56
-
57
- if request.method in ["POST", "GET"]:
58
- args = request.form if request.method == "POST" else request.args
59
- query = args.get('query')
60
- start_year = args.get('start_year', None)
61
- end_year = args.get('end_year', None)
62
- venue_type = args.getlist('venue_type', None)
63
- is_findings = args.get('is_findings', None)
64
-
65
- if not is_valid_query(query):
66
- abort(400, "Invalid query :(")
67
-
68
- # Get top passage IDs from ColBERT
69
- colbert_response = api_search_query(query)
70
-
71
- # Query MySQL database for paper information
72
- pids = [r['pid'] for r in colbert_response["topk"]]
73
- mysql_response = query_paper_metadata(
74
- pids,
75
- start_year=start_year,
76
- end_year=end_year,
77
- venue_type=venue_type,
78
- is_findings=is_findings
79
- )
80
-
81
- K = 20
82
- mysql_response = mysql_response[:K]
83
-
84
- return mysql_response
85
-
86
-
87
- # @app.route('/search', methods=['POST', 'GET'])
88
- # def search_web():
89
- # return render_template('public/results.html', query=query, year=year, results=results)
90
-
91
-
92
- @app.route('/', methods=['POST', 'GET'])
93
- def index():
94
- return render_template('index.html')
95
-
96
-
97
- if __name__ == "__main__":
98
- """
99
- Example usage:
100
- python server.py
101
- http://localhost:8893/api/colbert?query=Information retrieval with BERT
102
- http://localhost:8893/api/search?query=Information retrieval with BERT
103
- """
104
- create_database()
105
- init_colbert(index_path=INDEX_PATH)
106
- app.run("0.0.0.0", PORT) # debug=True
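
For reference, a usage sketch of the removed `/api/search` endpoint, assuming the server is running locally on port 8893 as in the removed README:

```python
# Sketch: query the removed Flask /api/search endpoint (server assumed on :8893).
import requests

resp = requests.get(
    "http://localhost:8893/api/search",
    params={
        "query": "Information retrieval with BERT",
        "start_year": 2021,
        "end_year": 2024,
        "venue_type": "main",
    },
)
for paper in resp.json():
    print(paper["year"], paper["title"], paper["url"])
```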
src/static/style.css DELETED
@@ -1,83 +0,0 @@
1
- :root {
2
- --custom-red: #ED1C24;
3
- --custom-red-dark: #D11920;
4
- --custom-blue: #446e9b;
5
- }
6
- .card {
7
- margin-bottom: 20px;
8
- }
9
- .search-container {
10
- display: flex;
11
- margin-bottom: 20px;
12
- }
13
- .search-container .form-control {
14
- margin-right: 10px;
15
- }
16
- .btn-primary {
17
- background-color: var(--custom-red);
18
- border-color: var(--custom-red);
19
- }
20
- .btn-primary:hover, .btn-primary:focus, .btn-primary:active {
21
- background-color: var(--custom-red-dark);
22
- border-color: var(--custom-red-dark);
23
- }
24
- /* Custom styling for range input */
25
- input[type="range"] {
26
- width: 100%;
27
- height: 8px;
28
- border-radius: 5px;
29
- background: #b6b6b6;
30
- outline: none;
31
- }
32
- input[type="range"]::-webkit-slider-thumb {
33
- -webkit-appearance: none;
34
- appearance: none;
35
- width: 20px;
36
- height: 20px;
37
- border-radius: 50%;
38
- background: var(--custom-red);
39
- cursor: pointer;
40
- }
41
- input[type="range"]::-moz-range-thumb {
42
- width: 20px;
43
- height: 20px;
44
- border-radius: 50%;
45
- background: var(--custom-red);
46
- cursor: pointer;
47
- }
48
- /* Custom styling for checkboxes */
49
- .form-check-input:checked {
50
- background-color: var(--custom-red);
51
- border-color: var(--custom-red);
52
- }
53
- .form-check-input:focus {
54
- border-color: var(--custom-red);
55
- box-shadow: 0 0 0 0.25rem rgba(237, 28, 36, 0.25);
56
- }
57
- /* Custom styling for links */
58
- a {
59
- color: var(--custom-blue) !important;
60
- text-decoration: none;
61
- }
62
- a:hover, a:focus {
63
- color: #446e9b !important;
64
- text-decoration: underline;
65
- }
66
- /* David custom */
67
- .card-text:last-child {
68
- font-size: 11pt;
69
- line-height: 1.05 !important;
70
- }
71
- .card {
72
- background-color: #f8f9fa !important;
73
- color: #212529
74
- }
75
- .card-body h6 {
76
- font-size: 11pt;
77
- }
78
- .paper-metadata {
79
- font-size: 10pt;
80
- }
81
- .range-label {
82
- text-align: center
83
- }
src/templates/index.html DELETED
@@ -1,137 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8">
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
- <title>ACL Search</title>
8
- <link href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.3.0/css/bootstrap.min.css" rel="stylesheet">
9
- <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
10
- <script src="https://cdnjs.cloudflare.com/ajax/libs/axios/0.21.1/axios.min.js"></script>
11
- </head>
12
-
13
- <body>
14
- <div class="container mt-5">
15
- <form id="searchForm">
16
- <div class="row">
17
- <div class="col-md-3">
18
- <div class="mb-3">
19
- <h1 class="mb-4">ACL Search</h1>
20
- <label for="yearRange" class="form-label range-label"><span id="yearRangeValue"></span></label>
21
- <input type="range" class="form-range" id="yearRange" min="2010" max="2024" step="1"
22
- value="2021">
23
- </div>
24
-
25
- <div class="mb-3">
26
- <div class="form-check">
27
- <input class="form-check-input" type="checkbox" value="main" id="main" checked>
28
- <label class="form-check-label" for="main">Main Long</label>
29
- </div>
30
- <div class="form-check">
31
- <input class="form-check-input" type="checkbox" value="short" id="short" checked>
32
- <label class="form-check-label" for="short">Main Short</label>
33
- </div>
34
- <div class="form-check">
35
- <input class="form-check-input" type="checkbox" value="findings" id="findings">
36
- <label class="form-check-label" for="findings">Findings</label>
37
- </div>
38
- <div class="form-check">
39
- <input class="form-check-input" type="checkbox" value="journal" id="journal">
40
- <label class="form-check-label" for="journal">Journal</label>
41
- </div>
42
- <div class="form-check">
43
- <input class="form-check-input" type="checkbox" value="workshop" id="workshop">
44
- <label class="form-check-label" for="workshop">Workshop</label>
45
- </div>
46
- <div class="form-check">
47
- <input class="form-check-input" type="checkbox" value="demo" id="demo">
48
- <label class="form-check-label" for="demo">Demo Track</label>
49
- </div>
50
- <div class="form-check">
51
- <input class="form-check-input" type="checkbox" value="industry" id="industry">
52
- <label class="form-check-label" for="industry">Industry Track</label>
53
- </div>
54
- <div class="form-check">
55
- <input class="form-check-input" type="checkbox" value="tutorial" id="tutorial">
56
- <label class="form-check-label" for="tutorial">Tutorial Abstracts</label>
57
- </div>
58
- </div>
59
- </div>
60
- <div class="col-md-9">
61
- <div class="search-container">
62
- <input type="text" class="form-control" id="query" name="query" placeholder="Information is the resolution of uncertainty" required>
63
- <button type="submit" class="btn btn-primary" form="searchForm">Search</button>
64
- </div>
65
- <div id="results" class="mt-4"></div>
66
- </div>
67
- </div>
68
- </form>
69
- </div>
70
-
71
- <script>
72
- const yearRange = document.getElementById('yearRange');
73
- const yearRangeValue = document.getElementById('yearRangeValue');
74
-
75
- function updateYearRangeValue() {
76
- const startYear = parseInt(yearRange.value);
77
- const endYear = 2024;
78
- yearRangeValue.textContent = `${startYear} - ${endYear}`;
79
- }
80
-
81
- yearRange.addEventListener('input', updateYearRangeValue);
82
- updateYearRangeValue(); // Initial call to set the text
83
-
84
- document.getElementById('searchForm').addEventListener('submit', async (e) => {
85
- e.preventDefault();
86
- const form = e.target;
87
-
88
- const params = new URLSearchParams();
89
- params.append('query', form.query.value);
90
- params.append('start_year', yearRange.value);
91
- params.append('end_year', '2024');
92
-
93
- const selectedVenues = Array.from(document.querySelectorAll('input[type="checkbox"]:checked'))
94
- .map(checkbox => checkbox.value);
95
- selectedVenues.forEach(venue => params.append('venue_type', venue));
96
-
97
- try {
98
- const response = await axios.get('http://localhost:8893/api/search', { params });
99
- displayResults(response.data);
100
- } catch (error) {
101
- console.error('Error:', error);
102
- document.getElementById('results').innerHTML = '<p class="alert alert-danger">An error occurred while fetching results.</p>';
103
- }
104
- });
105
-
106
- function displayResults(data) {
107
- const resultsDiv = document.getElementById('results');
108
- if (data.length === 0) {
109
- resultsDiv.innerHTML = '<p class="alert alert-info">No results found.</p>';
110
- return;
111
- }
112
-
113
- console.log(data)
114
-
115
- let html = '';
116
- data.forEach(paper => {
117
- html += `
118
- <div class="card">
119
- <div class="card-body">
120
- <h5 class="card-title"><a href="${paper.url}" target="_blank">${paper.title}</a></h5>
121
- <h6 class="card-subtitle mb-1 text-muted">${paper.authors}</h6>
122
- <p class="card-subtitle text-muted paper-metadata">
123
- <strong>${paper.year} / ${paper.venue_type}</strong> <br>
124
- ${paper.is_findings ? '<br><strong>Findings Paper</strong>' : ''}
125
- </p>
126
- <p class="card-text"><small class="text-muted">${paper.abstract}</small></p>
127
- </div>
128
- </div>
129
- `;
130
- });
131
-
132
- resultsDiv.innerHTML = html;
133
- }
134
- </script>
135
- </body>
136
-
137
- </html>
src/utils.py DELETED
@@ -1,95 +0,0 @@
1
- import torch
2
- import tqdm
3
-
4
- def maxsim(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
5
- ncentroids, nquery_vectors = centroid_scores.shape
6
- centroid_scores = centroid_scores.flatten()
7
- scores = []
8
-
9
- for i in tqdm.tqdm(range(len(pids)), desc='Calculating maxsim over centroids...'):
10
- seen_codes = set()
11
- per_doc_scores = torch.full((nquery_vectors,), -9999, dtype=torch.float32)
12
-
13
- pid = pids[i]
14
- for j in range(doclens[pid]):
15
- code = codes[offsets[pid] + j]
16
- assert code < ncentroids
17
- if idx[code] and code not in seen_codes:
18
- for k in range(nquery_vectors):
19
- per_doc_scores[k] = torch.max(
20
- per_doc_scores[k],
21
- centroid_scores[code * nquery_vectors + k]
22
- )
23
- seen_codes.add(code)
24
-
25
- score = torch.sum(per_doc_scores[:nquery_vectors]).item()
26
- scores += [(score, pid)]
27
-
28
- # Sort and return scores
29
- global_scores = sorted(scores, key=lambda x: x[0], reverse=True)
30
- filtered_pids = [pid for _, pid in global_scores[:nfiltered_docs]]
31
- filtered_pids = torch.tensor(filtered_pids, dtype=torch.int32)
32
-
33
- return filtered_pids
34
-
35
-
36
- def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
37
- filtered_pids = maxsim(
38
- pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs
39
- )
40
-
41
- print('Stage 2 filtering:', pids.shape, '->', filtered_pids.shape) # (all_docs) -> (n_docs/4)
42
-
43
- nfinal_filtered_docs = int(nfiltered_docs / 4)
44
- ones = [True] * centroid_scores.size(0)
45
-
46
- final_filtered_pids = maxsim(
47
- filtered_pids, centroid_scores, codes, doclens, offsets, ones, nfinal_filtered_docs
48
- )
49
-
50
- print('Stage 3 filtering:', filtered_pids.shape, '->', final_filtered_pids.shape) # (n_docs) -> (n_docs/4)
51
-
52
- return final_filtered_pids
53
-
54
-
55
- def decompress_residuals(pids, doclens, offsets, bucket_weights, reversed_bit_map,
56
- bucket_weight_combinations, binary_residuals, codes,
57
- centroids, dim, nbits):
58
- npacked_vals_per_byte = 8 // nbits
59
- packed_dim = dim // npacked_vals_per_byte
60
- cumulative_lengths = [0 for _ in range(len(pids)+1)]
61
- noutputs = 0
62
- for i in range(len(pids)):
63
- noutputs += doclens[pids[i]]
64
- cumulative_lengths[i + 1] = cumulative_lengths[i] + doclens[pids[i]]
65
-
66
- output = []
67
-
68
- binary_residuals = binary_residuals.flatten()
69
- centroids = centroids.flatten()
70
-
71
- # Iterate over all documents
72
- for i in range(len(pids)):
73
- pid = pids[i]
74
-
75
- # Offset into packed list of token vectors for the given document
76
- offset = offsets[pid]
77
-
78
- # For each document, iterate over all token vectors
79
- for j in range(doclens[pid]):
80
- code = codes[offset + j]
81
-
82
- # For each token vector, iterate over the packed (8-bit) residual values
83
- for k in range(packed_dim):
84
- x = binary_residuals[(offset + j) * packed_dim + k]
85
- x = reversed_bit_map[x]
86
-
87
- # For each packed residual value, iterate over the bucket weight indices.
88
- # If we use n-bit compression, that means there will be (8 / n) indices per packed value.
89
- for l in range(npacked_vals_per_byte):
90
- output_dim_idx = k * npacked_vals_per_byte + l
91
- bucket_weight_idx = bucket_weight_combinations[x * npacked_vals_per_byte + l]
92
- output[(cumulative_lengths[i] + j) * dim + output_dim_idx] = \
93
- bucket_weights[bucket_weight_idx] + centroids[code * dim + output_dim_idx]
94
-
95
- return output