diff --git a/app.py b/app.py
index 4180787eb376c5d3c53a364292f976ed7575e684..9bc7644361ac1de5a283f0e0e36ff20dfd458b61 100644
--- a/app.py
+++ b/app.py
@@ -1,24 +1,23 @@
-# import sys
-# sys.path.insert(0, './petals/')
+import sys
+sys.path.insert(0, './petals/')
-# import torch
-# import transformers
+import torch
+import transformers
import gradio as gr
-# from src.client.remote_model import DistributedBloomForCausalLM
+from src.client.remote_model import DistributedBloomForCausalLM
-# MODEL_NAME = "bigscience/test-bloomd-6b3" # select model you like
-# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
+MODEL_NAME = "bigscience/test-bloomd-6b3" # select model you like
+INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
-# tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
-# model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
+# Load the tokenizer and the client-side model from the same MODEL_NAME so the two cannot drift apart
+tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
def inference(text, seq_length=1):
- return text
- # input_ids = tokenizer([text], return_tensors="pt").input_ids
- # output = model.generate(input_ids, max_new_tokens=seq_length)
- # return tokenizer.batch_decode(output)[0]
+ # Continue `text` with up to `seq_length` newly generated tokens and return the decoded string.
+ # The prompt is tokenized as a batch containing a single sequence.
+ input_ids = tokenizer([text], return_tensors="pt").input_ids
+ # Generation runs through the module-level `model`, which was created with INITIAL_PEERS.
+ output = model.generate(input_ids, max_new_tokens=seq_length)
+ # Decode the whole batch and return its only element.
+ return tokenizer.batch_decode(output)[0]
iface = gr.Interface(
fn=inference,
diff --git a/petals/README.md b/petals/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a36322247087f82cd86920699682205ed8343683
--- /dev/null
+++ b/petals/README.md
@@ -0,0 +1,203 @@
+
+
+ Decentralized platform for running 100B+ language models
+
+
+
+
+
+
+
+
+## Key features
+
+- Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
+- It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows you to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning.
+- This way, one inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
+- Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.
+
+
+ [Read paper] | [View website]
+
+
+## How it works?
+
+
+
+
+
+### 🛠️ Examples
+
+Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
+
+This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
+
+```python
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+ "bigscience/bloom-petals", tuning_mode="ptune", initial_peers=SEE_BELOW
+) # Embeddings & prompts are on your device, BLOOM blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
+
+# Training (updates only local prompts / adapters)
+optimizer = torch.optim.AdamW(model.parameters())
+for input_ids, labels in data_loader:
+ outputs = model.forward(input_ids)
+ loss = cross_entropy(outputs.logits, labels)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+```
+
+### 🚧 This project is in active development
+
+Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
+
+A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
+
+### 🔒 Privacy and security
+
+If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data.
+
+This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, if there are malicious peers, they may alter their outputs to influence the model outputs. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
+
+## FAQ
+
+1. **What's the motivation for people to host model layers in the public swarm?**
+
+ People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may also be motivated to "give back" to the community that helps them run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
+
+ Since this may not be enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
+
+2. **Why is the platform named "Petals"?**
+
+ "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom).
+
+ While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future.
+
+## Installation
+
+Here's how to install the dependencies with conda:
+```
+conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+pip install -r requirements.txt
+```
+
+This script uses Anaconda to install cuda-enabled PyTorch.
+If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
+If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
+If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
+
+__OS support:__ Currently, Petals only supports Linux operating systems. On Windows 11, you can run Petals with GPU enabled inside WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)).
+For macOS, you can *probably* run everything normally if you manage to install dependencies, but we do not guarantee this.
+
+
+## 🚀 Getting Started
+
+This is a toy example running on a local machine without GPU and with a tiny model.
+For detailed instructions on running larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
+
+First, run a couple of servers, each in a separate shell. To launch your first server, run:
+```bash
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+ --host_maddrs /ip4/127.0.0.1/tcp/31337 # use port 31337, local connections only
+```
+
+This server will host 8 (out of 24) blocks of a [tiny 560M version](https://huggingface.co/bloom-testing/test-bloomd-560m-main) of the BLOOM model that was converted for Petals.
+
+> If you'd like to run a swarm of servers with the full BLOOM straight away, please see [these instructions](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) (you'll need several GPUs!). To run a different model, see [this wiki page](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-PETALS).
+
+Once the server has started, it will print out a ton of information, including an important line like this:
+
+```bash
+Mon Day 01:23:45.678 [INFO] Running DHT node on ['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters'], initial peers = []
+```
+
+You can use this address (`/ip4/whatever/else`) to connect additional servers. Open another terminal and run:
+
+```bash
+python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
+ --host_maddrs /ip4/127.0.0.1/tcp/0 \
+ --initial_peers /ip4/127.0... # <-- TODO: Copy the address of another server here
+# e.g. --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS1GecIfYouAreReadingThisYouNeedToCopyYourServerAddressCBBq
+```
+
+You can assign `--initial_peers` to one or multiple addresses of other servers, not necessarily the first one.
+The only requirement is that at least one of them is running at the time.
+
+Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
+make sure your servers have enough total `--num_blocks` to cover that model.
+
+Once you have enough servers, you can use them to train the model and/or run inference with it:
+```python
+import torch
+import torch.nn.functional as F
+import transformers
+from src import DistributedBloomForCausalLM
+
+initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
+tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
+model = DistributedBloomForCausalLM.from_pretrained(
+ "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
+) # this model has only embeddings / logits, all transformer blocks rely on remote servers
+
+
+inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
+remote_outputs = model.generate(inputs, max_length=10)
+print(tokenizer.decode(remote_outputs[0])) # "a cat sat in the back of the car,"
+
+# "train" input embeddings by backprop through distributed transformer blocks
+model.transformer.word_embeddings.weight.requires_grad = True
+outputs = model.forward(input_ids=inputs)
+loss = F.cross_entropy(outputs.logits.flatten(0, 1), inputs.flatten())
+loss.backward()
+print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
+```
+
+Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
+
+Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals.
+
+## 🛠️ Development
+
+Petals uses pytest with a few plugins. To install them, run `pip install -r requirements-dev.txt`
+
+To run minimalistic tests, spin up some servers:
+
+```bash
+export MODEL_NAME=bloom-testing/test-bloomd-560m-main
+export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+python -m cli.run_server $MODEL_NAME --block_indices 0:12 --throughput 1 --torch_dtype float32 \
+ --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> server1.log &
+sleep 5 # wait for the first server to initialize DHT
+python -m cli.run_server $MODEL_NAME --block_indices 12:24 --throughput 1 --torch_dtype float32 \
+ --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g &> server2.log &
+
+tail -f server1.log server2.log # view logs for both servers
+# after you're done, kill servers with 'pkill -f cli.run_server'
+```
+
+Then launch pytest:
+
+```
+export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m
+export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
+```
+
+The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml).
+
+### Code style
+
+We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests.
+Before committing your code, simply run `black . && isort .` and you will be fine.
+
+--------------------------------------------------------------------------------
+
+
+ This project is a part of the BigScience research workshop.
+
+
+
+
diff --git a/petals/cli/__init__.py b/petals/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/petals/cli/config.json b/petals/cli/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ca7ffbbb014181a0794c524c5b12802970f1f7bf
--- /dev/null
+++ b/petals/cli/config.json
@@ -0,0 +1,20 @@
+{
+ "apply_residual_connection_post_layernorm": false,
+ "attention_dropout": 0.0,
+ "attention_softmax_in_fp32": true,
+ "bos_token_id": 1,
+ "eos_token_id": 2,
+ "hidden_dropout": 0.0,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "masked_softmax_fusion": true,
+ "model_type": "bloom",
+ "n_embed": 14336,
+ "n_layer": 70,
+ "num_attention_heads": 112,
+ "pretraining_tp": 4,
+ "slow_but_exact": false,
+ "transformers_version": "4.20.0.dev0",
+ "use_cache": true,
+ "vocab_size": 250880
+}
\ No newline at end of file
diff --git a/petals/cli/convert_model.py b/petals/cli/convert_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..91f3bd848969e43b873236994a28843631ef6046
--- /dev/null
+++ b/petals/cli/convert_model.py
@@ -0,0 +1,93 @@
+import argparse
+import os
+
+import psutil
+import torch.backends.quantized
+import torch.nn as nn
+import transformers
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from huggingface_hub import Repository
+from tqdm.auto import tqdm
+
+from src import BloomModel
+from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
+from src.client import DistributedBloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
+
+
+if __name__ == "__main__":
+    # Script entry point: load a source model, then push every transformer block and the
+    # client-side modules (everything except the blocks) to branches of a HF hub repository.
+    parser = argparse.ArgumentParser(
+        description="Load a BLOOM model, split it into per-block checkpoints and push them to a HF hub repo."
+    )
+
+    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
+    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
+    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
+    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
+    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
+    parser.add_argument(
+        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
+    )
+    parser.add_argument(
+        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
+    )
+    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
+    args = parser.parse_args()
+
+    free_ram_gb = psutil.virtual_memory().available / 2**30
+    if args.model == "bigscience/bloom" and free_ram_gb < 400:
+        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
+
+    # Validate explicitly: an assert would be stripped when running with `python -O`
+    if args.torch_dtype not in DTYPE_MAP:
+        parser.error(f"torch_dtype must be one of {list(DTYPE_MAP.keys())}")
+
+    # Test isdir first: os.listdir raises NotADirectoryError when the path is a regular file,
+    # which would hide the intended FileExistsError below
+    if os.path.exists(args.output_path) and (
+        not os.path.isdir(args.output_path) or len(os.listdir(args.output_path)) != 0
+    ):
+        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
+
+    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
+    config = DistributedBloomConfig.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision
+    )
+    config.dht_prefix = args.output_repo
+
+    model = BloomModel.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
+    )
+    if args.resize_token_embeddings:
+        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
+        model.resize_token_embeddings(args.resize_token_embeddings)
+        # Keep the saved config in sync with the resized embedding matrix
+        config.vocab_size = args.resize_token_embeddings
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision
+    )
+    os.makedirs(args.output_path, exist_ok=True)
+
+    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
+    repo.git_pull()
+
+    transformer_blocks = model.h
+    # Block branches are numbered from 0, so the last one is len(transformer_blocks) - 1
+    logger.info(
+        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
+        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks) - 1}"
+    )
+    for i, block in enumerate(tqdm(transformer_blocks)):
+        # Return to the client branch before each commit, then commit this block to its own branch
+        repo.git_checkout(args.client_branch, create_branch_ok=True)
+        with repo.commit(
+            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
+        ):
+            # NOTE(review): relative path; presumably the commit context runs inside the local clone - confirm
+            torch.save(block.state_dict(), "./pytorch_model.bin")
+
+    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
+    repo.git_checkout(args.client_branch, create_branch_ok=True)
+    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
+        # Replace the blocks with an empty list so that save_pretrained writes only the remaining modules
+        model.h = nn.ModuleList()
+        model.save_pretrained(".")
+        tokenizer.save_pretrained(".")
+        config.save_pretrained(".")
+
+    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
diff --git a/petals/cli/deploy_server.sh b/petals/cli/deploy_server.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1b0cc24145b404d556d28ff5394fb525ddbc2dde
--- /dev/null
+++ b/petals/cli/deploy_server.sh
@@ -0,0 +1,79 @@
+#!/usr/bin/env bash
+
+#################
+# Parse options #
+#################
+
+# Print the usage message to stderr and abort with a non-zero exit code.
+instructions() {
+    echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
+    echo " -m: model name" >&2
+    echo " -i: initial peer" >&2
+    echo " -d: device" >&2
+    echo " -p: server identity path" >&2
+    echo " -b: block_ids" >&2
+    echo " -a: host maddrs" >&2
+    echo " -t: whether to run local tests" >&2
+    exit 1
+}
+
+# Refuse to run with fewer than 8 command-line words
+if [ $# -lt 8 ]; then
+    instructions
+fi
+
+# The leading ":" selects silent error reporting: unknown options are reported as "?",
+# options with a missing argument as ":". "-t" is a plain flag, so it has no trailing colon.
+while getopts ":m:i:d:p:b:a:t" option; do
+    case $option in
+        m) MODEL_NAME=${OPTARG}
+        ;;
+        i) INITIAL_PEER=${OPTARG}
+        ;;
+        d) DEVICE=${OPTARG}
+        ;;
+        p) SERVER_ID_PATH=${OPTARG}
+        ;;
+        b) BLOCK_IDS=${OPTARG}
+        ;;
+        a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
+        ;;
+        t) RUN_LOCAL_TESTS=true
+        ;;
+        \?|:) instructions
+        ;;
+    esac
+done
+
+
+echo "=========="
+echo "= Config ="
+echo "=========="
+echo "Model name: ${MODEL_NAME}"
+echo "Initial peer: ${INITIAL_PEER}"
+echo "Device: ${DEVICE}"
+echo "Server name: ${SERVER_ID_PATH}"
+echo "Server address: ${HOST_MADDR}"
+echo "Bloom blocks: ${BLOCK_IDS}"
+
+
+###########################
+# Install or activate env #
+###########################
+
+# TODO fix bug with self calling
+source ~/miniconda3/etc/profile.d/conda.sh
+if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
+ conda activate bloom-demo
+else
+ conda create -y --name bloom-demo python=3.8.12 pip
+ conda activate bloom-demo
+
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
+ pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+ pip install -i https://pypi.org/simple -r requirements.txt
+ pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
+fi
+
+##############
+# Run server #
+##############
+
+# Quote every expansion so that values containing spaces or glob characters are passed through unchanged
+python -m cli.run_server --converted_model_name_or_path "${MODEL_NAME}" --device "${DEVICE}" --initial_peer "${INITIAL_PEER}" \
+    --block_indices "${BLOCK_IDS}" --compression UNIFORM_8BIT --identity_path "${SERVER_ID_PATH}" --host_maddrs "${HOST_MADDR}" --load_in_8bit &> "${SERVER_ID_PATH}.log"
diff --git a/petals/cli/inference_one_block.py b/petals/cli/inference_one_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bde04f7da562288f243dd9670a77af72a70c01f
--- /dev/null
+++ b/petals/cli/inference_one_block.py
@@ -0,0 +1,53 @@
+import argparse
+
+import torch
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tqdm.auto import trange
+
+from src.bloom.block import BloomBlock
+from src.bloom.model import BloomConfig
+from src.bloom.ops import build_alibi_tensor
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+logger.warning("inference_one_block will soon be deprecated in favour of tests!")
+
+
+def print_device_info(device=None):
+    """Log the selected device and, for CUDA devices, its name and memory statistics.
+
+    Code adapted from https://stackoverflow.com/a/53374933/12891528
+
+    :param device: device name or torch.device; when omitted, "cuda" is used if available, else "cpu"
+    """
+    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+    logger.info(f"Using device: {device}")
+
+    # Additional Info when using cuda
+    if device.type == "cuda":
+        # Query the requested device instead of always reporting GPU 0
+        logger.info(torch.cuda.get_device_name(device))
+        logger.info("Memory Usage:")
+        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(device) / 1024 ** 3, 1)} GB")
+        # torch.cuda.memory_cached is deprecated; memory_reserved is its replacement
+        logger.info(f"Cached: {round(torch.cuda.memory_reserved(device) / 1024 ** 3, 1)} GB")
+
+
+if __name__ == "__main__":
+    # Script entry point: build one BloomBlock from a config file and run it step by step
+    # on random inputs, carrying the attention cache over between steps.
+    parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
+    parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
+    parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
+    parser.add_argument("--layer_index", default=0, type=int, help="Layer index passed to BloomBlock")
+    parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
+    parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
+    args = parser.parse_args()
+
+    if args.device is None:
+        args.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    config = BloomConfig.from_json_file(args.config)
+    block = BloomBlock(config, args.layer_index)
+    if args.state_dict is not None:
+        # Previously --state_dict was accepted but ignored; load the weights on CPU before moving the block
+        block.load_state_dict(torch.load(args.state_dict, map_location="cpu"))
+    block = block.to(args.device)
+
+    cache = None
+
+    for i in trange(args.num_steps):
+        # One new random token embedding per step
+        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
+        # The alibi tensor is rebuilt for length i + 1, i.e. all steps so far plus the current one
+        alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
+        with torch.no_grad():
+            outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
+
+    print_device_info(args.device)
diff --git a/petals/cli/local_server_config_example.cfg b/petals/cli/local_server_config_example.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..8cbfe458d2155655ecf8ad97fd6fe32e51c9ca47
--- /dev/null
+++ b/petals/cli/local_server_config_example.cfg
@@ -0,0 +1,5 @@
+device=cpu
+block_ids=2:3
+id_path=./server.id
+maddr=/ip4/127.0.0.1/tcp/30000
+#
diff --git a/petals/cli/remote_server_config_example.cfg b/petals/cli/remote_server_config_example.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..54df7afbec9d14987d57394e0e2d24712f488327
--- /dev/null
+++ b/petals/cli/remote_server_config_example.cfg
@@ -0,0 +1,6 @@
+name=bloom-peer-0.bloom.net
+device=cpu
+block_ids=1:3
+id_path=./server.id
+maddr=/ip4/0.0.0.0/tcp/30000
+#
\ No newline at end of file
diff --git a/petals/cli/run_local_servers.sh b/petals/cli/run_local_servers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..51a802a4c4a97fdc5adf55ef630632a3334f162d
--- /dev/null
+++ b/petals/cli/run_local_servers.sh
@@ -0,0 +1,109 @@
+#!/usr/bin/env bash
+
+#################
+# Parse options #
+#################
+
+# Print the usage message to stderr and abort with a non-zero exit code.
+instructions() {
+    echo "Usage: $0 [-n] [-c]" >&2
+    echo " -n: number of servers to run" >&2
+    echo " -c: path to the server configs" >&2
+    exit 1
+}
+
+# Exactly two "-x value" pairs are expected
+if [ $# -ne 4 ]; then
+    instructions
+fi
+
+# Only -n and -c exist; the leading ":" reports unknown options as "?" and missing arguments as ":"
+while getopts ":n:c:" option; do
+    case $option in
+        n) NUM_SERVERS=${OPTARG}
+        ;;
+        c) CONFIG_PATH=${OPTARG}
+        ;;
+        \?|:) instructions
+        ;;
+    esac
+done
+
+# Both options are mandatory: the server loop needs the count and the config directory
+if [ -z "${NUM_SERVERS}" ] || [ -z "${CONFIG_PATH}" ]; then
+    instructions
+fi
+
+
+###########################
+# Install or activate env #
+###########################
+
+source ~/miniconda3/etc/profile.d/conda.sh
+if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
+ conda activate bloom-demo
+else
+ conda create -y --name bloom-demo python=3.8.12 pip
+ conda activate bloom-demo
+
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
+ pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+ pip install -i https://pypi.org/simple -r requirements.txt
+ pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
+fi
+
+
+#######################
+# Create Initial peer #
+#######################
+
+hivemind-dht &> tmp.out &
+sleep 5
+INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
+echo "Initial peer: ${INITIAL_PEER}"
+
+
+##############################
+# Initialize the config file #
+##############################
+
+typeset -A cfg
+cfg=( # set default values in config array
+ [device]="cpu"
+ [block_ids]="1:2"
+ [id_path]="server.id"
+ [maddr]="/ip4/127.0.0.1/tcp/30000"
+)
+
+###############
+# Run servers #
+###############
+
+for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
+do
+    ###############
+    # Read config #
+    ###############
+
+    # -r keeps backslashes literal; the extra test also processes a final line that has no trailing newline.
+    # Values are written into the shared cfg array, so keys missing from a file keep their previous value.
+    while read -r line || [ -n "$line" ]
+    do
+        # Only "key=value" lines are used; the key is everything before the first "="
+        if [[ "$line" == *=* ]]
+        then
+            varname=${line%%=*}
+            cfg[$varname]=${line#*=}
+        fi
+    done < "${CONFIG_PATH}/server_${SERVER_ID}.cfg"
+
+    echo "=== Server #${SERVER_ID} ==="
+    echo "Server ID: ${cfg[id_path]}"
+    echo "Device: ${cfg[device]}"
+    echo "Bloom block ids: ${cfg[block_ids]}"
+    echo "Host maddr: ${cfg[maddr]}"
+    echo ""
+
+    ##############
+    # Run server #
+    ##############
+
+    # Each server runs in its own detached tmux session; arguments are quoted to survive spaces
+    tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i "${INITIAL_PEER}" -d "${cfg[device]}" -p "${cfg[id_path]}" -b "${cfg[block_ids]}" -a "${cfg[maddr]}"
+done
+
+#####################
+# Kill initial peer #
+#####################
+
+sleep 10
+pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
+rm tmp.out
\ No newline at end of file
diff --git a/petals/cli/run_remote_servers.sh b/petals/cli/run_remote_servers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..43df8d8b1926afea8591a77b0d88db75e410568e
--- /dev/null
+++ b/petals/cli/run_remote_servers.sh
@@ -0,0 +1,110 @@
+#!/usr/bin/env bash
+
+SSH_KEY_PATH="~/.ssh/"
+
+#################
+# Parse options #
+#################
+
+instructions() {
+ echo "Usage: $0 [-u] [-n] [-c]" >&2
+ echo " -u: username" >&2
+ echo " -n: number of servers to run" >&2
+ echo " -c: path to the server configs" >&2
+ exit 1
+}
+
+if [ $# != 6 ]; then
+ instructions
+fi
+
+while getopts ":u:n:c:" option; do
+ case $option in
+ u) USERNAME=${OPTARG}
+ ;;
+ n) NUM_SERVERS=${OPTARG}
+ ;;
+ c) CONFIG_PATH=${OPTARG}
+ ;;
+ \?) instructions
+ ;;
+ esac
+done
+
+
+###########################
+# Install or activate env #
+###########################
+
+source ~/miniconda3/etc/profile.d/conda.sh
+if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
+ conda activate bloom-demo
+else
+ conda create -y --name bloom-demo python=3.8.12 pip
+ conda activate bloom-demo
+
+ conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
+ pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+ pip install -i https://pypi.org/simple -r requirements.txt
+fi
+
+
+#######################
+# Create Initial peer #
+#######################
+
+hivemind-dht &> tmp.out &
+
+sleep 5
+INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
+rm tmp.out
+echo "Initial peer: ${INITIAL_PEER}"
+
+
+##############################
+# Initialize the config file #
+##############################
+
+typeset -A cfg
+cfg=( # set default values in config array
+ [name]=""
+ [device]="cpu"
+ [block_ids]="1:2"
+ [id_path]="server.id"
+ [maddr]="/ip4/0.0.0.0/tcp/30000"
+)
+
+###############
+# Run servers #
+###############
+
+for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
+do
+    ###############
+    # Read config #
+    ###############
+
+    # -r keeps backslashes literal; the extra test also processes a final line that has no trailing newline.
+    # Values are written into the shared cfg array, so keys missing from a file keep their previous value.
+    while read -r line || [ -n "$line" ]
+    do
+        # Only "key=value" lines are used; the key is everything before the first "="
+        if [[ "$line" == *=* ]]
+        then
+            varname=${line%%=*}
+            cfg[$varname]=${line#*=}
+        fi
+    done < "${CONFIG_PATH}/server_${SERVER_ID}.cfg"
+
+    SERVER_NAME="${USERNAME}@${cfg[name]}"
+    echo "=== Server #${SERVER_ID} ==="
+    echo "Server name ${SERVER_NAME}"
+    echo "Server ID: ${cfg[id_path]}"
+    echo "Device: ${cfg[device]}"
+    echo "Bloom block ids: ${cfg[block_ids]}"
+    echo "Host maddr: ${cfg[maddr]}"
+    echo "================="
+
+    ##############
+    # Run server #
+    ##############
+
+    # deploy_server.sh needs the model name (-m); use the same model as run_local_servers.sh
+    ssh -i "${SSH_KEY_PATH}" "${SERVER_NAME}" "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -m bigscience/test-bloomd -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
+done
\ No newline at end of file
diff --git a/petals/cli/run_server.py b/petals/cli/run_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcef351d108cafceefe0d885f3a49c90582fa7fb
--- /dev/null
+++ b/petals/cli/run_server.py
@@ -0,0 +1,129 @@
+import argparse
+
+import configargparse
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.limits import increase_file_limit
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from humanfriendly import parse_size
+
+from src.server.server import Server
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def main():
+    """Parse the server options, start a ``Server`` and block until it stops.
+
+    Options come from the command line and, through ``configargparse``, from
+    ``config.yml`` or the file given with ``-c/--config``. The model is given
+    either positionally or with ``--converted_model_name_or_path`` (exactly one
+    of the two). ``--compression`` is resolved to a ``CompressionType`` member
+    and ``--attn_cache_size`` is converted to bytes with ``parse_size``; every
+    remaining option is forwarded to ``Server`` as a keyword argument.
+
+    The server is always shut down on exit, including after Ctrl-C.
+    """
+    # fmt:off
+    parser = configargparse.ArgParser(default_config_files=["config.yml"],
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument('--converted_model_name_or_path', type=str, default=None,
+                       help="path or name of a pretrained model, converted with cli/convert_model.py")
+    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
+
+    parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
+    parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+    # The trailing space keeps the two concatenated literals from running together.
+    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default, "
+                                                                  "use the same name as in the converted model.")
+    parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
+                        help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
+    parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
+                        help='Visible multiaddrs the host announces for external connections from other p2p instances')
+
+    parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
+
+    parser.add_argument('--num_handlers', type=int, default=8, required=False,
+                        help='server will use this many processes to handle incoming requests')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all operations (in total tokens)')
+    parser.add_argument('--max_batch_size', type=int, default=16384,
+                        help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
+                        help='Pre-form this many subsequent batches while GPU is processing the current one')
+    parser.add_argument('--sender_threads', type=int, default=1, required=False,
+                        help='Use this many threads to pass results/exceptions from Runtime to Pools')
+    parser.add_argument('--inference_max_length', type=int, default=16384,
+                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+    parser.add_argument('--cache_dir', type=str, default=None,
+                        help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument("--torch_dtype", type=str, default="auto",
+                        help="Use this dtype to store block weights and do computations. "
+                             "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--attn_cache_size', type=str, default=None,
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
+    # The trailing space keeps "models" and "and" apart in the rendered help.
+    parser.add_argument('--revision', type=str, default='main',
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
+
+    parser.add_argument('--throughput',
+                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        default='auto',
+                        help='Expected server throughput (a float measured in RPS). '
+                             'If set to "auto" (default), the script evaluates network and compute throughput '
+                             'on the first run and uses these estimates for future runs. '
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+    parser.add_argument('--update_period', type=float, required=False, default=30,
+                        help='Server will report blocks to DHT once in this many seconds')
+    parser.add_argument('--expiration', type=float, required=False, default=None,
+                        help='DHT entries will expire after this many seconds')
+    parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
+                        help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
+    parser.add_argument('--increase_file_limit', action='store_true',
+                        help='On *nix, this will increase the max number of processes '
+                             'a server can spawn before hitting "Too many open files"; Use at your own risk.')
+    parser.add_argument('--stats_report_interval', type=int, required=False,
+                        help='Interval between two reports of batch processing performance statistics')
+
+    parser.add_argument('--custom_module_path', type=str, required=False,
+                        help='Path of a file with custom nn.modules, wrapped into special decorator')
+    parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
+
+    parser.add_argument("--balance_quality", type=float, default=0.75,
+                        help="Rebalance the swarm if its throughput is worse than this share of the optimal "
+                             "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
+                             "on each check for debugging purposes.")
+    parser.add_argument("--mean_balance_check_period", type=float, default=60,
+                        help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
+
+    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
+
+    # fmt:on
+    args = vars(parser.parse_args())
+    # The config path itself is not a Server argument.
+    args.pop("config", None)
+
+    # The positional form wins when both spellings of the model name are present.
+    args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
+
+    if args.pop("increase_file_limit"):
+        increase_file_limit()
+
+    # Look the compression name up case-insensitively on the CompressionType enum.
+    compression_type = args.pop("compression").upper()
+    compression = getattr(CompressionType, compression_type)
+
+    attn_cache_size = args.pop("attn_cache_size")
+    if attn_cache_size is not None:
+        attn_cache_size = parse_size(attn_cache_size)
+    assert isinstance(
+        attn_cache_size, (int, type(None))
+    ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
+
+    # "True", "true" and the empty string become the boolean True; any other
+    # value (including None) is passed through unchanged.
+    use_auth_token = args.pop("use_auth_token")
+    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
+
+    server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
+
+    try:
+        server.join()
+    except KeyboardInterrupt:
+        logger.info("Caught KeyboardInterrupt, shutting down")
+    finally:
+        server.shutdown()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/petals/cli/speed_test.py b/petals/cli/speed_test.py
new file mode 100755
index 0000000000000000000000000000000000000000..d3342c312f342d33e743ed7b300d4147403be17c
--- /dev/null
+++ b/petals/cli/speed_test.py
@@ -0,0 +1,1941 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2012 Matt Martz
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import csv
+import datetime
+import errno
+import math
+import os
+import platform
+import re
+import signal
+import socket
+import sys
+import threading
+import timeit
+import xml.parsers.expat
+
+try:
+ import gzip
+
+ GZIP_BASE = gzip.GzipFile
+except ImportError:
+ gzip = None
+ GZIP_BASE = object
+
+__version__ = "2.1.4b1"
+
+
+class FakeShutdownEvent(object):
+    """Minimal stand-in for ``threading.Event`` that is never set.
+
+    Lets users of this module skip registering their own
+    ``threading.Event()``; both spellings of the query method are offered.
+    """
+
+    @staticmethod
+    def isSet():
+        """Dummy method to always return false"""
+        return False
+
+    # Modern spelling, bound to the very same static method.
+    is_set = isSet
+
+
+# Some global variables we use
+DEBUG = False
+# Unique sentinel meaning "no timeout argument was supplied".
+_GLOBAL_DEFAULT_TIMEOUT = object()
+# Interpreter version flags, compared against the running (major, minor) version.
+PY25PLUS = sys.version_info[:2] >= (2, 5)
+PY26PLUS = sys.version_info[:2] >= (2, 6)
+PY32PLUS = sys.version_info[:2] >= (3, 2)
+PY310PLUS = sys.version_info[:2] >= (3, 10)
+
+# Begin import game to handle Python 2 and Python 3
+try:
+ import json
+except ImportError:
+ try:
+ import simplejson as json
+ except ImportError:
+ json = None
+
+try:
+ import xml.etree.ElementTree as ET
+
+ try:
+ from xml.etree.ElementTree import _Element as ET_Element
+ except ImportError:
+ pass
+except ImportError:
+ from xml.dom import minidom as DOM
+ from xml.parsers.expat import ExpatError
+
+ ET = None
+
+try:
+ from urllib2 import (
+ AbstractHTTPHandler,
+ HTTPDefaultErrorHandler,
+ HTTPError,
+ HTTPErrorProcessor,
+ HTTPRedirectHandler,
+ OpenerDirector,
+ ProxyHandler,
+ Request,
+ URLError,
+ urlopen,
+ )
+except ImportError:
+ from urllib.request import (
+ AbstractHTTPHandler,
+ HTTPDefaultErrorHandler,
+ HTTPError,
+ HTTPErrorProcessor,
+ HTTPRedirectHandler,
+ OpenerDirector,
+ ProxyHandler,
+ Request,
+ URLError,
+ urlopen,
+ )
+
+try:
+ from httplib import BadStatusLine, HTTPConnection
+except ImportError:
+ from http.client import BadStatusLine, HTTPConnection
+
+try:
+ from httplib import HTTPSConnection
+except ImportError:
+ try:
+ from http.client import HTTPSConnection
+ except ImportError:
+ HTTPSConnection = None
+
+try:
+ from httplib import FakeSocket
+except ImportError:
+ FakeSocket = None
+
+try:
+ from Queue import Queue
+except ImportError:
+ from queue import Queue
+
+try:
+ from urlparse import urlparse
+except ImportError:
+ from urllib.parse import urlparse
+
+try:
+ from urlparse import parse_qs
+except ImportError:
+ try:
+ from urllib.parse import parse_qs
+ except ImportError:
+ from cgi import parse_qs
+
+try:
+ from hashlib import md5
+except ImportError:
+ from md5 import md5
+
+try:
+ from argparse import SUPPRESS as ARG_SUPPRESS, ArgumentParser as ArgParser
+
+ PARSER_TYPE_INT = int
+ PARSER_TYPE_STR = str
+ PARSER_TYPE_FLOAT = float
+except ImportError:
+ from optparse import SUPPRESS_HELP as ARG_SUPPRESS, OptionParser as ArgParser
+
+ PARSER_TYPE_INT = "int"
+ PARSER_TYPE_STR = "string"
+ PARSER_TYPE_FLOAT = "float"
+
+try:
+ from cStringIO import StringIO
+
+ BytesIO = None
+except ImportError:
+ try:
+ from StringIO import StringIO
+
+ BytesIO = None
+ except ImportError:
+ from io import BytesIO, StringIO
+
+try:
+ import __builtin__
+except ImportError:
+ import builtins
+ from io import FileIO, TextIOWrapper
+
+    class _Py3Utf8Output(TextIOWrapper):
+        """Text stream over an existing file descriptor that always encodes
+        as UTF-8 (strict errors) and flushes after every write, used on py3
+        to override an ASCII stdout
+        """
+
+        def __init__(self, f, **kwargs):
+            # Extra keyword arguments are accepted but not used.
+            raw = FileIO(f.fileno(), "w")
+            TextIOWrapper.__init__(self, raw, encoding="utf8", errors="strict")
+
+        def write(self, s):
+            # Push each write straight through to the descriptor.
+            TextIOWrapper.write(self, s)
+            self.flush()
+
+ _py3_print = getattr(builtins, "print")
+ try:
+ _py3_utf8_stdout = _Py3Utf8Output(sys.stdout)
+ _py3_utf8_stderr = _Py3Utf8Output(sys.stderr)
+ except OSError:
+ # sys.stdout/sys.stderr is not a compatible stdout/stderr object
+ # just use it and hope things go ok
+ _py3_utf8_stdout = sys.stdout
+ _py3_utf8_stderr = sys.stderr
+
+    def to_utf8(v):
+        """Return *v* unchanged; py3 needs no explicit utf-8 encoding here"""
+        return v
+
+    def print_(*args, **kwargs):
+        """Call the builtin print with output routed to a utf-8 stream.
+
+        ``file=sys.stderr`` is swapped for the utf-8 stderr wrapper, an
+        explicitly passed ``file`` is kept as given, and a missing ``file``
+        defaults to the utf-8 stdout wrapper.
+        """
+        requested = kwargs.get("file")
+        if requested == sys.stderr:
+            destination = _py3_utf8_stderr
+        elif "file" in kwargs:
+            destination = requested
+        else:
+            destination = _py3_utf8_stdout
+        kwargs["file"] = destination
+        _py3_print(*args, **kwargs)
+
+else:
+ del __builtin__
+
+    def to_utf8(v):
+        """Encode value to utf-8 if possible for py2"""
+        try:
+            return v.encode("utf8", "strict")
+        except AttributeError:
+            # Values without an .encode() method are returned unchanged.
+            return v
+
+    def print_(*args, **kwargs):
+        """The new-style print function for Python 2.4 and 2.5.
+
+        Taken from https://pypi.python.org/pypi/six/
+
+        Modified to set encoding to UTF-8 always, and to flush after write
+
+        Accepts the ``file``, ``sep`` and ``end`` keyword arguments; any
+        other keyword raises ``TypeError``, as does a ``sep`` or ``end``
+        that is not a string. ``file=None`` prints nothing.
+        """
+        fp = kwargs.pop("file", sys.stdout)
+        if fp is None:
+            return
+
+        def write(data):
+            # Write one piece of output to fp and flush immediately.
+            if not isinstance(data, basestring):
+                data = str(data)
+            # If the file has an encoding, encode unicode with it.
+            encoding = "utf8"  # Always trust UTF-8 for output
+            if isinstance(fp, file) and isinstance(data, unicode) and encoding is not None:
+                errors = getattr(fp, "errors", None)
+                if errors is None:
+                    errors = "strict"
+                data = data.encode(encoding, errors)
+            fp.write(data)
+            fp.flush()
+
+        # Separator and terminator become unicode as soon as sep, end or
+        # any argument is unicode.
+        want_unicode = False
+        sep = kwargs.pop("sep", None)
+        if sep is not None:
+            if isinstance(sep, unicode):
+                want_unicode = True
+            elif not isinstance(sep, str):
+                raise TypeError("sep must be None or a string")
+        end = kwargs.pop("end", None)
+        if end is not None:
+            if isinstance(end, unicode):
+                want_unicode = True
+            elif not isinstance(end, str):
+                raise TypeError("end must be None or a string")
+        # Anything still left in kwargs is an unsupported keyword.
+        if kwargs:
+            raise TypeError("invalid keyword arguments to print()")
+        if not want_unicode:
+            for arg in args:
+                if isinstance(arg, unicode):
+                    want_unicode = True
+                    break
+        if want_unicode:
+            newline = unicode("\n")
+            space = unicode(" ")
+        else:
+            newline = "\n"
+            space = " "
+        if sep is None:
+            sep = space
+        if end is None:
+            end = newline
+        # Emit the arguments separated by sep, then the terminator.
+        for i, arg in enumerate(args):
+            if i:
+                write(sep)
+            write(arg)
+        write(end)
+
+
+# Exception "constants" to support Python 2 through Python 3
+try:
+    import ssl
+
+    try:
+        CERT_ERROR = (ssl.CertificateError,)
+    except AttributeError:
+        # This ssl module has no CertificateError; contribute nothing to the tuple.
+        CERT_ERROR = tuple()
+
+    # Exception types that are caught together as HTTP/connection failures.
+    HTTP_ERRORS = (HTTPError, URLError, socket.error, ssl.SSLError, BadStatusLine) + CERT_ERROR
+except ImportError:
+    # No ssl module: the same tuple without the SSL-specific exceptions.
+    ssl = None
+    HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine)
+
+# Pick the ElementTree iteration method by interpreter version.
+# NOTE(review): neither branch runs below Python 2.5, which leaves etree_iter
+# undefined there -- confirm that is intended.
+if PY32PLUS:
+    etree_iter = ET.Element.iter
+elif PY25PLUS:
+    etree_iter = ET_Element.getiterator
+
+# Thread liveness check under whichever name this interpreter provides.
+if PY26PLUS:
+    thread_is_alive = threading.Thread.is_alive
+else:
+    thread_is_alive = threading.Thread.isAlive
+
+
+def event_is_set(event):
+    """Return whether *event* is set, using ``is_set()`` when the object has
+    it and falling back to the older ``isSet()`` spelling on AttributeError
+    """
+    try:
+        flag = event.is_set()
+    except AttributeError:
+        flag = event.isSet()
+    return flag
+
+
+class SpeedtestException(Exception):
+    """Base exception for this module; every exception below derives from it"""
+
+
+class SpeedtestCLIError(SpeedtestException):
+    """Generic exception for raising errors during CLI operation"""
+
+
+class SpeedtestHTTPError(SpeedtestException):
+    """Base HTTP exception for this module"""
+
+
+class SpeedtestConfigError(SpeedtestException):
+    """Configuration XML is invalid"""
+
+
+class SpeedtestServersError(SpeedtestException):
+    """Servers XML is invalid"""
+
+
+class ConfigRetrievalError(SpeedtestHTTPError):
+    """Could not retrieve config.php"""
+
+
+class ServersRetrievalError(SpeedtestHTTPError):
+    """Could not retrieve speedtest-servers.php"""
+
+
+class InvalidServerIDType(SpeedtestException):
+    """Server ID used for filtering was not an integer"""
+
+
+class NoMatchedServers(SpeedtestException):
+    """No servers matched when filtering"""
+
+
+class SpeedtestMiniConnectFailure(SpeedtestException):
+    """Could not connect to the provided speedtest mini server"""
+
+
+class InvalidSpeedtestMiniServer(SpeedtestException):
+    """Server provided as a speedtest mini server does not actually appear
+    to be a speedtest mini server
+    """
+
+
+class ShareResultsConnectFailure(SpeedtestException):
+    """Could not connect to speedtest.net API to POST results"""
+
+
+class ShareResultsSubmitFailure(SpeedtestException):
+    """Unable to successfully POST results to speedtest.net API after
+    connection
+    """
+
+
+class SpeedtestUploadTimeout(SpeedtestException):
+    """testlength configuration reached during upload.
+
+    Used to ensure the upload halts when no additional data should be sent
+    """
+
+
+class SpeedtestBestServerFailure(SpeedtestException):
+    """Unable to determine best server"""
+
+
+class SpeedtestMissingBestServer(SpeedtestException):
+    """get_best_server not called or not able to determine best server"""
+
+
+def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
+    """Open a TCP connection to *address* and return the connected socket.
+
+    *address* is a ``(host, port)`` 2-tuple. Each candidate returned by
+    ``socket.getaddrinfo`` is tried in order and the first one that connects
+    is returned. A supplied *timeout* is set on the socket before
+    connecting; a truthy *source_address* (a ``(host, port)`` tuple) is
+    bound as the source address first.
+
+    Raises the last ``socket.error`` seen when every candidate fails, or a
+    ``socket.error`` of its own when ``getaddrinfo`` returns nothing.
+
+    Largely vendored from Python 2.7, modified to work with Python 2.4
+    """
+
+    host, port = address
+    last_error = None
+    for family, kind, protocol, _canonname, sockaddr in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
+        candidate = None
+        try:
+            candidate = socket.socket(family, kind, protocol)
+            if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
+                candidate.settimeout(float(timeout))
+            if source_address:
+                candidate.bind(source_address)
+            candidate.connect(sockaddr)
+            return candidate
+
+        except socket.error:
+            # Remember the failure, drop the half-built socket, try the next one.
+            last_error = get_exception()
+            if candidate is not None:
+                candidate.close()
+
+    if last_error is None:
+        raise socket.error("getaddrinfo returns an empty list")
+    raise last_error
+
+
+class SpeedtestHTTPConnection(HTTPConnection):
+    """HTTPConnection variant that accepts ``source_address`` and
+    ``timeout`` keyword arguments across Python 2.4 - Python 3
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Strip our two extras before the base class sees the keywords.
+        bind_to = kwargs.pop("source_address", None)
+        wait = kwargs.pop("timeout", 10)
+
+        self._tunnel_host = None
+
+        HTTPConnection.__init__(self, *args, **kwargs)
+
+        self.source_address, self.timeout = bind_to, wait
+
+    def connect(self):
+        """Connect to the host and port specified in __init__."""
+        endpoint = (self.host, self.port)
+        try:
+            self.sock = socket.create_connection(endpoint, self.timeout, self.source_address)
+        except (AttributeError, TypeError):
+            # The stdlib helper is missing or has an older signature: use ours.
+            self.sock = create_connection(endpoint, self.timeout, self.source_address)
+
+        if self._tunnel_host:
+            self._tunnel()
+
+
+if HTTPSConnection:
+
+    class SpeedtestHTTPSConnection(HTTPSConnection):
+        """HTTPSConnection variant that accepts ``source_address`` and
+        ``timeout`` keyword arguments across Python 2.4 - Python 3
+        """
+
+        default_port = 443
+
+        def __init__(self, *args, **kwargs):
+            # Strip our two extras before the base class sees the keywords.
+            bind_to = kwargs.pop("source_address", None)
+            wait = kwargs.pop("timeout", 10)
+
+            self._tunnel_host = None
+
+            HTTPSConnection.__init__(self, *args, **kwargs)
+
+            self.timeout, self.source_address = wait, bind_to
+
+        def connect(self):
+            "Connect to a host on a given (SSL) port."
+            endpoint = (self.host, self.port)
+            try:
+                self.sock = socket.create_connection(endpoint, self.timeout, self.source_address)
+            except (AttributeError, TypeError):
+                # The stdlib helper is missing or has an older signature: use ours.
+                self.sock = create_connection(endpoint, self.timeout, self.source_address)
+
+            if self._tunnel_host:
+                self._tunnel()
+
+            if ssl:
+                try:
+                    wrap_args = {}
+                    if hasattr(ssl, "SSLContext"):
+                        # Name the tunnel target when tunnelling, else the host itself.
+                        wrap_args["server_hostname"] = self._tunnel_host or self.host
+                    self.sock = self._context.wrap_socket(self.sock, **wrap_args)
+                except AttributeError:
+                    # An attribute was missing on the context path: wrap through the module.
+                    self.sock = ssl.wrap_socket(self.sock)
+                    try:
+                        self.sock.server_hostname = self.host
+                    except AttributeError:
+                        pass
+            elif FakeSocket:
+                # Python 2.4/2.5 support
+                try:
+                    self.sock = FakeSocket(self.sock, socket.ssl(self.sock))
+                except AttributeError:
+                    raise SpeedtestException("This version of Python does not support HTTPS/SSL functionality")
+            else:
+                raise SpeedtestException("This version of Python does not support HTTPS/SSL functionality")
+
+
+def _build_connection(connection, source_address, timeout, context=None):
+    """Return a callable that builds *connection* (an ``HTTPConnection`` or
+    ``HTTPSConnection`` class) for a host with our ``source_address`` and
+    ``timeout`` filled in, plus ``context`` when one is given. Works across
+    Python 2.4 - Python 3.
+
+    Called from ``http(s)_open`` methods of ``SpeedtestHTTPHandler`` or
+    ``SpeedtestHTTPSHandler``
+    """
+
+    def inner(host, **kwargs):
+        # Our values override anything the caller passed under the same names.
+        kwargs["source_address"] = source_address
+        kwargs["timeout"] = timeout
+        if context:
+            kwargs["context"] = context
+        return connection(host, **kwargs)
+
+    return inner
+
+
+class SpeedtestHTTPHandler(AbstractHTTPHandler):
+    """``HTTPHandler`` replacement whose connections are built with our
+    ``source_address`` and ``timeout``
+    """
+
+    def __init__(self, debuglevel=0, source_address=None, timeout=10):
+        AbstractHTTPHandler.__init__(self, debuglevel)
+        self.source_address, self.timeout = source_address, timeout
+
+    def http_open(self, req):
+        # Hand do_open a factory that already carries our settings.
+        factory = _build_connection(SpeedtestHTTPConnection, self.source_address, self.timeout)
+        return self.do_open(factory, req)
+
+    http_request = AbstractHTTPHandler.do_request_
+
+
+class SpeedtestHTTPSHandler(AbstractHTTPHandler):
+    """``HTTPSHandler`` replacement whose connections are built with our
+    ``source_address``, ``timeout`` and optional ``context``
+    """
+
+    def __init__(self, debuglevel=0, context=None, source_address=None, timeout=10):
+        AbstractHTTPHandler.__init__(self, debuglevel)
+        self._context = context
+        self.source_address, self.timeout = source_address, timeout
+
+    def https_open(self, req):
+        # Hand do_open a factory that already carries our settings.
+        factory = _build_connection(
+            SpeedtestHTTPSConnection,
+            self.source_address,
+            self.timeout,
+            context=self._context,
+        )
+        return self.do_open(factory, req)
+
+    https_request = AbstractHTTPHandler.do_request_
+
+
+def build_opener(source_address=None, timeout=10):
+    """Build an ``OpenerDirector`` much like ``urllib2.build_opener`` does,
+    but with an explicit handler list, optional binding to
+    ``source_address``, our ``timeout`` and our custom `User-Agent`
+    """
+
+    printer("Timeout set to %d" % timeout, debug=True)
+
+    bind_address = None
+    if source_address:
+        # Port 0 is paired with the requested source host.
+        bind_address = (source_address, 0)
+        printer("Binding to source address: %r" % (bind_address,), debug=True)
+
+    handlers = (
+        ProxyHandler(),
+        SpeedtestHTTPHandler(source_address=bind_address, timeout=timeout),
+        SpeedtestHTTPSHandler(source_address=bind_address, timeout=timeout),
+        HTTPDefaultErrorHandler(),
+        HTTPRedirectHandler(),
+        HTTPErrorProcessor(),
+    )
+
+    director = OpenerDirector()
+    director.addheaders = [("User-agent", build_user_agent())]
+    for handler in handlers:
+        director.add_handler(handler)
+
+    return director
+
+
+class GzipDecodedResponse(GZIP_BASE):
+    """File-like reader that gunzips an HTTP response body (gzip method,
+    RFC 1952).
+
+    Largely copied from ``xmlrpclib``/``xmlrpc.client`` and modified
+    to work for py2.4-py3
+    """
+
+    def __init__(self, response):
+        # response doesn't support tell() and read(), required by
+        # GzipFile, so the whole body is first copied into memory
+        if not gzip:
+            raise SpeedtestHTTPError("HTTP response body is gzip encoded, but gzip support is not available")
+        buffer_type = BytesIO or StringIO
+        self.io = buffer_type()
+        chunk = response.read(1024)
+        while len(chunk) != 0:
+            self.io.write(chunk)
+            chunk = response.read(1024)
+        self.io.seek(0)
+        gzip.GzipFile.__init__(self, mode="rb", fileobj=self.io)
+
+    def close(self):
+        # The in-memory buffer is closed even if closing the gzip layer fails.
+        try:
+            gzip.GzipFile.close(self)
+        finally:
+            self.io.close()
+
+
+def get_exception():
+    """Return the exception instance currently being handled, in a way
+    that works inside a try/except block on py2.4-py3
+    """
+    _kind, value, _trace = sys.exc_info()
+    return value
+
+
+def distance(origin, destination):
+    """Return the distance in km between two ``[lat, lon]`` pairs given in
+    degrees, using a sphere of radius 6371 km
+    """
+
+    lat1, lon1 = origin
+    lat2, lon2 = destination
+    radius = 6371  # km
+
+    half_dlat_sin = math.sin(math.radians(lat2 - lat1) / 2)
+    half_dlon_sin = math.sin(math.radians(lon2 - lon1) / 2)
+    # Multiplication order is kept as-is so results stay bit-identical.
+    a = half_dlat_sin * half_dlat_sin + math.cos(math.radians(lat1)) * math.cos(
+        math.radians(lat2)
+    ) * half_dlon_sin * half_dlon_sin
+    central_angle = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
+
+    return radius * central_angle
+
+
+def build_user_agent():
+    """Assemble, log (debug only) and return a Mozilla/5.0 compatible
+    User-Agent string describing the platform, the Python version and this
+    tool's version
+    """
+
+    system = "(%s; U; %s; en-us)" % (platform.platform(), platform.architecture()[0])
+    pieces = [
+        "Mozilla/5.0",
+        system,
+        "Python/%s" % platform.python_version(),
+        "(KHTML, like Gecko)",
+        "speedtest-cli/%s" % __version__,
+    ]
+    user_agent = " ".join(pieces)
+    printer("User-Agent: %s" % user_agent, debug=True)
+    return user_agent
+
+
+def build_request(url, data=None, headers=None, bump="0", secure=False):
+    """Build a urllib2 request object
+
+    A ``Cache-Control: no-cache`` header is added to every request, and a
+    cache-busting ``x=<milliseconds>.<bump>`` query parameter is appended to
+    the URL. A *url* starting with ``:`` (for example ``://host/path``) gets
+    ``http`` or ``https`` prepended depending on *secure*.
+
+    *headers* is copied, so the caller's mapping is never modified.
+    The request is a POST when *data* is truthy, otherwise a GET.
+    """
+
+    # Copy instead of updating in place: the caller may reuse its dict.
+    if headers:
+        headers = dict(headers)
+    else:
+        headers = {}
+    headers["Cache-Control"] = "no-cache"
+
+    if url[0] == ":":
+        scheme = ("http", "https")[bool(secure)]
+        schemed_url = "%s%s" % (scheme, url)
+    else:
+        schemed_url = url
+
+    # Extend an existing query string, or start a new one.
+    if "?" in url:
+        delim = "&"
+    else:
+        delim = "?"
+
+    # WHO YOU GONNA CALL? CACHE BUSTERS!
+    final_url = "%s%sx=%s.%s" % (schemed_url, delim, int(timeit.time.time() * 1000), bump)
+
+    printer("%s %s" % (("GET", "POST")[bool(data)], final_url), debug=True)
+
+    return Request(final_url, data=data, headers=headers)
+
+
+def catch_request(request, opener=None):
+    """Open *request* and return ``(response, False)`` on success or
+    ``(None, exception)`` when one of the common HTTP/HTTPS connection
+    errors is raised. *opener*'s ``open`` is used when an opener is given,
+    ``urlopen`` otherwise.
+
+    """
+
+    if opener:
+        fetch = opener.open
+    else:
+        fetch = urlopen
+
+    try:
+        handle = fetch(request)
+        if request.get_full_url() != handle.geturl():
+            printer("Redirected to %s" % handle.geturl(), debug=True)
+        return handle, False
+    except HTTP_ERRORS:
+        failure = get_exception()
+        return None, failure
+
+
+def get_response_stream(response):
+    """Return a gzip-decoding reader when the response's
+    ``Content-Encoding`` is ``gzip``, otherwise the response itself
+
+    """
+
+    # Header lookup lives in a different place depending on the response type.
+    try:
+        header_lookup = response.headers.getheader
+    except AttributeError:
+        header_lookup = response.getheader
+
+    if header_lookup("content-encoding") != "gzip":
+        return response
+
+    return GzipDecodedResponse(response)
+
+
+def get_attributes_by_tag_name(dom, tag_name):
+    """Return the attributes of the first *tag_name* element in *dom* as a
+    plain dict
+
+    Only used with xml.dom.minidom, which is likely only to be used
+    with python versions older than 2.5
+    """
+    matches = dom.getElementsByTagName(tag_name)
+    first = matches[0]
+    return dict(first.attributes.items())
+
+
+def print_dots(shutdown_event):
+    """Return the built-in progress callback used by the Thread classes:
+    it writes one dot per call to stdout, plus a newline on the final
+    call, and stays silent once *shutdown_event* is set
+    """
+
+    def inner(current, total, start=False, end=False):
+        if not event_is_set(shutdown_event):
+            out = sys.stdout
+            out.write(".")
+            # Only the closing call for the last item ends the line.
+            if current + 1 == total and end is True:
+                out.write("\n")
+            out.flush()
+
+    return inner
+
+
+def do_nothing(*args, **kwargs):
+    """Accept any arguments, ignore them, and return ``None``"""
+    return None
+
+
+class HTTPDownloader(threading.Thread):
+    """Thread class for retrieving a URL
+
+    ``result`` starts as ``[0]`` and gains the size of every chunk read, so
+    ``sum(result)`` is the number of bytes downloaded. Reading stops at end
+    of data, when *timeout* seconds have passed since *start* (a
+    ``timeit.default_timer()`` value), or when *shutdown_event* is set.
+    """
+
+    def __init__(self, i, request, start, timeout, opener=None, shutdown_event=None):
+        threading.Thread.__init__(self)
+        self.request = request
+        self.result = [0]
+        self.starttime = start
+        self.timeout = timeout
+        self.i = i
+        if opener:
+            self._opener = opener.open
+        else:
+            self._opener = urlopen
+
+        if shutdown_event:
+            self._shutdown_event = shutdown_event
+        else:
+            self._shutdown_event = FakeShutdownEvent()
+
+    def run(self):
+        try:
+            if (timeit.default_timer() - self.starttime) <= self.timeout:
+                f = self._opener(self.request)
+                try:
+                    while (
+                        not event_is_set(self._shutdown_event)
+                        and (timeit.default_timer() - self.starttime) <= self.timeout
+                    ):
+                        self.result.append(len(f.read(10240)))
+                        if self.result[-1] == 0:
+                            break
+                finally:
+                    # Release the response even when a read fails part-way.
+                    f.close()
+        # Failures are swallowed on purpose: whatever was read so far stays
+        # in ``result``.
+        except IOError:
+            pass
+        except HTTP_ERRORS:
+            pass
+
+
+class HTTPUploaderData(object):
+    """File like object to improve cutting off the upload once the timeout
+    has been reached
+
+    The payload is ``content1=`` followed by a repeating 0-9A-Z pattern and
+    is exactly *length* bytes long whenever *length* is at least 9, matching
+    what ``len()`` reports. ``total`` starts as ``[0]`` and gains the size
+    of every chunk handed out by ``read``.
+    """
+
+    def __init__(self, length, start, timeout, shutdown_event=None):
+        self.length = length
+        self.start = start
+        self.timeout = timeout
+
+        if shutdown_event:
+            self._shutdown_event = shutdown_event
+        else:
+            self._shutdown_event = FakeShutdownEvent()
+
+        # Built lazily by pre_allocate() / the ``data`` property.
+        self._data = None
+
+        self.total = [0]
+
+    def pre_allocate(self):
+        """Build the in-memory payload.
+
+        Raises ``SpeedtestCLIError`` when there is not enough memory.
+        """
+        chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+        # Bytes left for the pattern after the 9-byte "content1=" prefix.
+        filler_length = max(int(self.length) - 9, 0)
+        # Round up: rounding to nearest could repeat the pattern too few
+        # times and leave the payload shorter than the advertised length.
+        multiplier = int(math.ceil(int(self.length) / 36.0))
+        IO = BytesIO or StringIO
+        try:
+            self._data = IO(("content1=%s" % (chars * multiplier)[0:filler_length]).encode())
+        except MemoryError:
+            raise SpeedtestCLIError("Insufficient memory to pre-allocate upload data. Please use --no-pre-allocate")
+
+    @property
+    def data(self):
+        if not self._data:
+            self.pre_allocate()
+        return self._data
+
+    def read(self, n=10240):
+        """Return up to *n* more bytes of payload.
+
+        Raises ``SpeedtestUploadTimeout`` once the timeout has elapsed or
+        the shutdown event is set.
+        """
+        if (timeit.default_timer() - self.start) <= self.timeout and not event_is_set(self._shutdown_event):
+            chunk = self.data.read(n)
+            self.total.append(len(chunk))
+            return chunk
+        else:
+            raise SpeedtestUploadTimeout()
+
+    def __len__(self):
+        return self.length
+
+
+class HTTPUploader(threading.Thread):
+    """Thread class for putting a URL
+
+    After ``run`` finishes, ``result`` holds the number of bytes handed to
+    the connection (the sum of ``request.data.total``), or 0 when the
+    upload never started or failed with an HTTP error.
+    """
+
+    def __init__(self, i, request, start, size, timeout, opener=None, shutdown_event=None):
+        threading.Thread.__init__(self)
+        self.request = request
+        # The request's data object measures its own timeout from the same start.
+        self.request.data.start = self.starttime = start
+        self.size = size
+        self.result = 0
+        self.timeout = timeout
+        self.i = i
+
+        if opener:
+            self._opener = opener.open
+        else:
+            self._opener = urlopen
+
+        if shutdown_event:
+            self._shutdown_event = shutdown_event
+        else:
+            self._shutdown_event = FakeShutdownEvent()
+
+    def run(self):
+        request = self.request
+        try:
+            if (timeit.default_timer() - self.starttime) <= self.timeout and not event_is_set(self._shutdown_event):
+                try:
+                    f = self._opener(request)
+                except TypeError:
+                    # PY24 expects a string or buffer
+                    # This also causes issues with Ctrl-C, but we will concede
+                    # for the moment that Ctrl-C on PY24 isn't immediate
+                    request = build_request(self.request.get_full_url(), data=request.data.read(self.size))
+                    f = self._opener(request)
+                try:
+                    f.read(11)
+                finally:
+                    # Release the response even when reading it fails.
+                    f.close()
+                self.result = sum(self.request.data.total)
+            else:
+                self.result = 0
+        except (IOError, SpeedtestUploadTimeout):
+            # Cut off mid-upload: still count what was sent.
+            self.result = sum(self.request.data.total)
+        except HTTP_ERRORS:
+            self.result = 0
+
+
+class SpeedtestResults(object):
+ """Class for holding the results of a speedtest, including:
+
+ Download speed
+ Upload speed
+ Ping/Latency to test server
+ Data about server that the test was run against
+
+ Additionally this class can return a result data as a dictionary or CSV,
+ as well as submit a POST of the result data to the speedtest.net API
+ to get a share results image link.
+ """
+
+    def __init__(self, download=0, upload=0, ping=0, server=None, client=None, opener=None, secure=False):
+        """Store the measured values and stamp the result with the current
+        UTC time; a default opener is built when none is supplied
+        """
+        self.download, self.upload, self.ping = download, upload, ping
+
+        # Only None is replaced for server; any falsy client becomes a new dict.
+        self.server = server
+        if server is None:
+            self.server = {}
+        self.client = client or {}
+
+        self._share = None
+        self.timestamp = "%sZ" % datetime.datetime.utcnow().isoformat()
+        self.bytes_received = 0
+        self.bytes_sent = 0
+
+        self._opener = opener
+        if not opener:
+            self._opener = build_opener()
+
+        self._secure = secure
+
+    def __repr__(self):
+        """Represent the results as the repr of their dictionary form"""
+        as_mapping = self.dict()
+        return repr(as_mapping)
+
+    def share(self):
+        """POST data to the speedtest.net API to obtain a share results
+        link
+
+        The link is cached, so later calls return it without another
+        request. Raises ``ShareResultsConnectFailure`` when the request
+        fails and ``ShareResultsSubmitFailure`` when the API does not answer
+        200 with exactly one ``resultid``.
+        """
+
+        if self._share:
+            return self._share
+
+        download = int(round(self.download / 1000.0, 0))
+        ping = int(round(self.ping, 0))
+        upload = int(round(self.upload / 1000.0, 0))
+
+        server_id = self.server["id"]
+        digest = md5(("%s-%s-%s-%s" % (ping, upload, download, "297aae72")).encode()).hexdigest()
+
+        # Build the request to send results back to speedtest.net
+        # We use a list instead of a dict because the API expects parameters
+        # in a certain order
+        api_data = [
+            "recommendedserverid=%s" % server_id,
+            "ping=%s" % ping,
+            "screenresolution=",
+            "promo=",
+            "download=%s" % download,
+            "screendpi=",
+            "upload=%s" % upload,
+            "testmethod=http",
+            "hash=%s" % digest,
+            "touchscreen=none",
+            "startmode=pingselect",
+            "accuracy=1",
+            "bytesreceived=%s" % self.bytes_received,
+            "bytessent=%s" % self.bytes_sent,
+            "serverid=%s" % server_id,
+        ]
+        body = "&".join(api_data).encode()
+
+        request = build_request(
+            "://www.speedtest.net/api/api.php",
+            data=body,
+            headers={"Referer": "http://c.speedtest.net/flash/speedtest.swf"},
+            secure=self._secure,
+        )
+        reply, error = catch_request(request, opener=self._opener)
+        if error:
+            raise ShareResultsConnectFailure(error)
+
+        response = reply.read()
+        code = reply.code
+        reply.close()
+
+        if int(code) != 200:
+            raise ShareResultsSubmitFailure("Could not submit results to speedtest.net")
+
+        resultid = parse_qs(response.decode()).get("resultid")
+        if not resultid or len(resultid) != 1:
+            raise ShareResultsSubmitFailure("Could not submit results to speedtest.net")
+
+        self._share = "http://www.speedtest.net/result/%s.png" % resultid[0]
+
+        return self._share
+
+ def dict(self):
+ """Return dictionary of result data"""
+
+ return {
+ "download": self.download,
+ "upload": self.upload,
+ "ping": self.ping,
+ "server": self.server,
+ "timestamp": self.timestamp,
+ "bytes_sent": self.bytes_sent,
+ "bytes_received": self.bytes_received,
+ "share": self._share,
+ "client": self.client,
+ }
+
+ @staticmethod
+ def csv_header(delimiter=","):
+ """Return CSV Headers"""
+
+ row = [
+ "Server ID",
+ "Sponsor",
+ "Server Name",
+ "Timestamp",
+ "Distance",
+ "Ping",
+ "Download",
+ "Upload",
+ "Share",
+ "IP Address",
+ ]
+ out = StringIO()
+ writer = csv.writer(out, delimiter=delimiter, lineterminator="")
+ writer.writerow([to_utf8(v) for v in row])
+ return out.getvalue()
+
+ def csv(self, delimiter=","):
+ """Return data in CSV format"""
+
+ data = self.dict()
+ out = StringIO()
+ writer = csv.writer(out, delimiter=delimiter, lineterminator="")
+ row = [
+ data["server"]["id"],
+ data["server"]["sponsor"],
+ data["server"]["name"],
+ data["timestamp"],
+ data["server"]["d"],
+ data["ping"],
+ data["download"],
+ data["upload"],
+ self._share or "",
+ self.client["ip"],
+ ]
+ writer.writerow([to_utf8(v) for v in row])
+ return out.getvalue()
+
+ def json(self, pretty=False):
+ """Return data in JSON format"""
+
+ kwargs = {}
+ if pretty:
+ kwargs.update({"indent": 4, "sort_keys": True})
+ return json.dumps(self.dict(), **kwargs)
+
+
+class Speedtest(object):
+ """Class for performing standard speedtest.net testing operations"""
+
+    def __init__(self, config=None, source_address=None, timeout=10, secure=False, shutdown_event=None):
+        """Prepare the opener, download the remote configuration and create
+        an empty result holder.
+
+        ``config`` entries, when given, are merged over the downloaded
+        configuration. A ``FakeShutdownEvent`` is substituted when no
+        ``shutdown_event`` is supplied.
+        """
+        # Plain state containers, filled in by the get_* methods
+        self.config = {}
+        self.servers = {}
+        self.closest = []
+        self._best = {}
+
+        self._source_address = source_address
+        self._timeout = timeout
+        self._secure = secure
+        self._opener = build_opener(source_address, timeout)
+        self._shutdown_event = shutdown_event or FakeShutdownEvent()
+
+        # Downloaded settings first, caller overrides second
+        self.get_config()
+        if config is not None:
+            self.config.update(config)
+
+        self.results = SpeedtestResults(client=self.config["client"], opener=self._opener, secure=secure)
+
+    @property
+    def best(self):
+        """Return the best server, selecting one on first access."""
+        if self._best:
+            return self._best
+        # Nothing cached yet: get_best_server() fills self._best in place
+        self.get_best_server()
+        return self._best
+
+    def get_config(self):
+        """Download the speedtest.net configuration and return only the data
+        we are interested in
+
+        Updates ``self.config`` and ``self.lat_lon`` and returns
+        ``self.config``. Returns ``None`` without touching ``self.config``
+        when the server does not answer with HTTP 200.
+
+        Raises ``ConfigRetrievalError`` when the request or the download
+        fails and ``SpeedtestConfigError`` when the XML is malformed or the
+        client location is not numeric.
+        """
+
+        headers = {}
+        if gzip:
+            headers["Accept-Encoding"] = "gzip"
+        request = build_request("://www.speedtest.net/speedtest-config.php", headers=headers, secure=self._secure)
+        uh, e = catch_request(request, opener=self._opener)
+        if e:
+            raise ConfigRetrievalError(e)
+        configxml_list = []
+
+        # Release the stream and the response handle even when a read fails
+        try:
+            stream = get_response_stream(uh)
+            try:
+                while 1:
+                    try:
+                        configxml_list.append(stream.read(1024))
+                    except (OSError, EOFError):
+                        raise ConfigRetrievalError(get_exception())
+                    if len(configxml_list[-1]) == 0:
+                        break
+            finally:
+                stream.close()
+        finally:
+            uh.close()
+
+        # NOTE(review): a non-200 answer leaves self.config without a
+        # "client" entry, which Speedtest.__init__ reads right afterwards
+        if int(uh.code) != 200:
+            return None
+
+        configxml = "".encode().join(configxml_list)
+
+        printer("Config XML:\n%s" % configxml, debug=True)
+
+        try:
+            try:
+                root = ET.fromstring(configxml)
+            except ET.ParseError:
+                e = get_exception()
+                raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
+            server_config = root.find("server-config").attrib
+            download = root.find("download").attrib
+            upload = root.find("upload").attrib
+            # times = root.find('times').attrib
+            client = root.find("client").attrib
+
+        except AttributeError:
+            # Any AttributeError in the ElementTree path falls back to a
+            # DOM-based parse of the same document
+            try:
+                root = DOM.parseString(configxml)
+            except ExpatError:
+                e = get_exception()
+                raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
+            server_config = get_attributes_by_tag_name(root, "server-config")
+            download = get_attributes_by_tag_name(root, "download")
+            upload = get_attributes_by_tag_name(root, "upload")
+            # times = get_attributes_by_tag_name(root, 'times')
+            client = get_attributes_by_tag_name(root, "client")
+
+        ignore_servers = [int(i) for i in server_config["ignoreids"].split(",") if i]
+
+        ratio = int(upload["ratio"])
+        upload_max = int(upload["maxchunkcount"])
+        up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032]
+        # The configured ratio selects how many of the smallest upload
+        # sizes are dropped
+        sizes = {"upload": up_sizes[ratio - 1 :], "download": [350, 500, 750, 1000, 1500, 2000, 2500, 3000, 3500, 4000]}
+
+        size_count = len(sizes["upload"])
+
+        upload_count = int(math.ceil(upload_max / size_count))
+
+        counts = {"upload": upload_count, "download": int(download["threadsperurl"])}
+
+        threads = {"upload": int(upload["threads"]), "download": int(server_config["threadcount"]) * 2}
+
+        length = {"upload": int(upload["testlength"]), "download": int(download["testlength"])}
+
+        self.config.update(
+            {
+                "client": client,
+                "ignore_servers": ignore_servers,
+                "sizes": sizes,
+                "counts": counts,
+                "threads": threads,
+                "length": length,
+                "upload_max": upload_count * size_count,
+            }
+        )
+
+        try:
+            self.lat_lon = (float(client["lat"]), float(client["lon"]))
+        except ValueError:
+            raise SpeedtestConfigError("Unknown location: lat=%r lon=%r" % (client.get("lat"), client.get("lon")))
+
+        printer("Config:\n%r" % self.config, debug=True)
+
+        return self.config
+
+    def get_servers(self, servers=None, exclude=None):
+        """Retrieve the list of speedtest.net servers, optionally filtered
+        to servers matching those specified in the ``servers`` argument
+
+        ``servers`` and ``exclude`` are lists of server IDs; their entries
+        are converted to ``int`` in place. The result is stored in
+        ``self.servers`` as a mapping of distance to a list of server
+        attribute dicts, and that mapping is returned.
+
+        Raises ``InvalidServerIDType`` for an ID that cannot be converted
+        to ``int`` and ``NoMatchedServers`` when a filter was given but no
+        server survived it.
+        """
+        if servers is None:
+            servers = []
+
+        if exclude is None:
+            exclude = []
+
+        self.servers.clear()
+
+        for server_list in (servers, exclude):
+            for i, s in enumerate(server_list):
+                try:
+                    server_list[i] = int(s)
+                except (TypeError, ValueError):
+                    # TypeError covers None and other non-numeric types
+                    raise InvalidServerIDType("%s is an invalid server type, must be int" % s)
+
+        urls = [
+            "://www.speedtest.net/speedtest-servers-static.php",
+            "http://c.speedtest.net/speedtest-servers-static.php",
+            "://www.speedtest.net/speedtest-servers.php",
+            "http://c.speedtest.net/speedtest-servers.php",
+        ]
+
+        headers = {}
+        if gzip:
+            headers["Accept-Encoding"] = "gzip"
+
+        # Try each URL in turn; the first one that yields a parsable list
+        # wins, any ServersRetrievalError moves on to the next URL
+        for url in urls:
+            try:
+                request = build_request(
+                    "%s?threads=%s" % (url, self.config["threads"]["download"]), headers=headers, secure=self._secure
+                )
+                uh, e = catch_request(request, opener=self._opener)
+                if e:
+                    printer("ERROR: %s" % e, debug=True)
+                    raise ServersRetrievalError()
+
+                # Release the stream and the response handle even when a
+                # read fails
+                try:
+                    stream = get_response_stream(uh)
+                    try:
+                        serversxml_list = []
+                        while 1:
+                            try:
+                                serversxml_list.append(stream.read(1024))
+                            except (OSError, EOFError):
+                                raise ServersRetrievalError(get_exception())
+                            if len(serversxml_list[-1]) == 0:
+                                break
+                    finally:
+                        stream.close()
+                finally:
+                    uh.close()
+
+                if int(uh.code) != 200:
+                    raise ServersRetrievalError()
+
+                serversxml = "".encode().join(serversxml_list)
+
+                printer("Servers XML:\n%s" % serversxml, debug=True)
+
+                try:
+                    try:
+                        try:
+                            root = ET.fromstring(serversxml)
+                        except ET.ParseError:
+                            e = get_exception()
+                            raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
+                        elements = etree_iter(root, "server")
+                    except AttributeError:
+                        # Fall back to a DOM-based parse
+                        try:
+                            root = DOM.parseString(serversxml)
+                        except ExpatError:
+                            e = get_exception()
+                            raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
+                        elements = root.getElementsByTagName("server")
+                except (SyntaxError, xml.parsers.expat.ExpatError):
+                    raise ServersRetrievalError()
+
+                for server in elements:
+                    try:
+                        attrib = server.attrib
+                    except AttributeError:
+                        attrib = dict(list(server.attributes.items()))
+
+                    # Parse the ID once and reuse it for every filter
+                    server_id = int(attrib.get("id"))
+
+                    if servers and server_id not in servers:
+                        continue
+
+                    if server_id in self.config["ignore_servers"] or server_id in exclude:
+                        continue
+
+                    try:
+                        d = distance(self.lat_lon, (float(attrib.get("lat")), float(attrib.get("lon"))))
+                    except Exception:
+                        # Servers without a usable location are skipped
+                        continue
+
+                    attrib["d"] = d
+
+                    self.servers.setdefault(d, []).append(attrib)
+
+                break
+
+            except ServersRetrievalError:
+                continue
+
+        if (servers or exclude) and not self.servers:
+            raise NoMatchedServers()
+
+        return self.servers
+
+    def set_mini_server(self, server):
+        """Instead of querying for a list of servers, set a link to a
+        speedtest mini server
+
+        ``server`` is the URL of the mini server. ``self.servers`` is
+        replaced by a one-element list describing it, and that list is
+        returned.
+
+        Raises ``SpeedtestMiniConnectFailure`` when the URL cannot be
+        fetched and ``InvalidSpeedtestMiniServer`` when no upload extension
+        can be determined.
+        """
+
+        urlparts = urlparse(server)
+
+        # A path with a file extension points at a page; use its directory
+        _name, ext = os.path.splitext(urlparts[2])
+        if ext:
+            url = os.path.dirname(server)
+        else:
+            url = server
+
+        request = build_request(url)
+        uh, e = catch_request(request, opener=self._opener)
+        if e:
+            raise SpeedtestMiniConnectFailure("Failed to connect to %s" % server)
+
+        try:
+            text = uh.read()
+        finally:
+            uh.close()
+
+        extension = re.findall('upload_?[Ee]xtension: "([^"]+)"', text.decode())
+        if not extension:
+            # The page did not advertise an extension: probe the known ones
+            for ext in ["php", "asp", "aspx", "jsp"]:
+                try:
+                    f = self._opener.open("%s/speedtest/upload.%s" % (url, ext))
+                except Exception:
+                    # Best-effort probe; a failure just means "try the next"
+                    continue
+
+                # Close every probe response, whether or not it matches
+                try:
+                    data = f.read().strip().decode()
+                    code = f.code
+                finally:
+                    f.close()
+
+                if code == 200 and len(data.splitlines()) == 1 and re.match("size=[0-9]", data):
+                    extension = [ext]
+                    break
+        if not urlparts or not extension:
+            raise InvalidSpeedtestMiniServer("Invalid Speedtest Mini Server: " "%s" % server)
+
+        self.servers = [
+            {
+                "sponsor": "Speedtest Mini",
+                "name": urlparts[1],
+                "d": 0,
+                "url": "%s/speedtest/upload.%s" % (url.rstrip("/"), extension[0]),
+                "latency": 0,
+                "id": 0,
+            }
+        ]
+
+        return self.servers
+
+    def get_closest_servers(self, limit=5):
+        """Limit servers to the closest speedtest.net servers based on
+        geographic distance
+
+        Fetches the server list first when ``self.servers`` is empty.
+        ``self.closest`` is rebuilt on every call with at most ``limit``
+        servers, nearest first, and is returned.
+        """
+
+        if not self.servers:
+            self.get_servers()
+
+        # Start from an empty list (cleared in place so existing references
+        # stay valid). Without this a repeated call starts above ``limit``,
+        # the equality test below never fires and every server is appended.
+        del self.closest[:]
+
+        for d in sorted(self.servers.keys()):
+            for s in self.servers[d]:
+                self.closest.append(s)
+                if len(self.closest) == limit:
+                    break
+            else:
+                # Inner loop finished without reaching the limit
+                continue
+            break
+
+        printer("Closest Servers:\n%r" % self.closest, debug=True)
+        return self.closest
+
+    def get_best_server(self, servers=None):
+        """Perform a speedtest.net "ping" to determine which speedtest.net
+        server has the lowest latency
+
+        ``servers`` is a list of server dicts, each with a ``url`` entry;
+        when empty, the closest servers are used. The winning server dict
+        gets a ``latency`` entry, is recorded in ``self.results`` and
+        ``self._best``, and is returned.
+
+        Raises ``SpeedtestBestServerFailure`` when there is no server to
+        choose from.
+        """
+
+        if not servers:
+            if not self.closest:
+                self.get_closest_servers()
+            servers = self.closest
+
+        if self._source_address:
+            source_address_tuple = (self._source_address, 0)
+        else:
+            source_address_tuple = None
+
+        user_agent = build_user_agent()
+
+        results = {}
+        for server in servers:
+            cum = []
+            url = os.path.dirname(server["url"])
+            stamp = int(timeit.time.time() * 1000)
+            latency_url = "%s/latency.txt?x=%s" % (url, stamp)
+            for i in range(0, 3):
+                this_latency_url = "%s.%s" % (latency_url, i)
+                printer("%s %s" % ("GET", this_latency_url), debug=True)
+                # NOTE(review): the request uses latency_url, not the
+                # per-attempt this_latency_url logged above -- confirm
+                # whether that is intended
+                urlparts = urlparse(latency_url)
+                h = None
+                try:
+                    try:
+                        if urlparts[0] == "https":
+                            h = SpeedtestHTTPSConnection(urlparts[1], source_address=source_address_tuple)
+                        else:
+                            h = SpeedtestHTTPConnection(urlparts[1], source_address=source_address_tuple)
+                        headers = {"User-Agent": user_agent}
+                        path = "%s?%s" % (urlparts[2], urlparts[4])
+                        start = timeit.default_timer()
+                        h.request("GET", path, headers=headers)
+                        r = h.getresponse()
+                        total = timeit.default_timer() - start
+                    except HTTP_ERRORS:
+                        e = get_exception()
+                        printer("ERROR: %r" % e, debug=True)
+                        # A failed attempt counts as a 3600 second latency
+                        cum.append(3600)
+                        continue
+
+                    text = r.read(9)
+                    if int(r.status) == 200 and text == "test=test".encode():
+                        cum.append(total)
+                    else:
+                        cum.append(3600)
+                finally:
+                    # Close the connection on every path, including errors
+                    if h is not None:
+                        h.close()
+
+            # NOTE(review): three samples are divided by 6, not 3 --
+            # presumably deliberate, confirm before changing
+            avg = round((sum(cum) / 6) * 1000.0, 3)
+            results[avg] = server
+
+        try:
+            fastest = sorted(results.keys())[0]
+        except IndexError:
+            raise SpeedtestBestServerFailure("Unable to connect to servers to " "test latency.")
+        best = results[fastest]
+        best["latency"] = fastest
+
+        self.results.ping = fastest
+        self.results.server = best
+
+        self._best.update(best)
+        printer("Best Server:\n%r" % best, debug=True)
+        return best
+
+    def download(self, callback=do_nothing, threads=None):
+        """Test download speed against speedtest.net
+
+        A ``threads`` value of ``None`` will fall back to those dictated
+        by the speedtest.net configuration
+
+        ``callback`` is invoked as ``callback(i, request_count, start=True)``
+        when a request thread is started and as
+        ``callback(i, request_count, end=True)`` when it has finished.
+
+        Stores the byte total and the speed in ``self.results`` and returns
+        the speed (bytes * 8 / elapsed seconds).
+        """
+
+        # One URL per configured image size, repeated "counts" times
+        urls = []
+        for size in self.config["sizes"]["download"]:
+            for _ in range(0, self.config["counts"]["download"]):
+                urls.append("%s/random%sx%s.jpg" % (os.path.dirname(self.best["url"]), size, size))
+
+        request_count = len(urls)
+        requests = []
+        for i, url in enumerate(urls):
+            requests.append(build_request(url, bump=i, secure=self._secure))
+
+        max_threads = threads or self.config["threads"]["download"]
+        # A dict so the nested functions below can update the counter
+        in_flight = {"threads": 0}
+
+        def producer(q, requests, request_count):
+            # Starts one downloader per request, waiting while max_threads
+            # are already in flight. ``start`` is read from the enclosing
+            # scope and is assigned before this thread is started.
+            for i, request in enumerate(requests):
+                thread = HTTPDownloader(
+                    i,
+                    request,
+                    start,
+                    self.config["length"]["download"],
+                    opener=self._opener,
+                    shutdown_event=self._shutdown_event,
+                )
+                while in_flight["threads"] >= max_threads:
+                    timeit.time.sleep(0.001)
+                thread.start()
+                q.put(thread, True)
+                in_flight["threads"] += 1
+                callback(i, request_count, start=True)
+
+        finished = []
+
+        def consumer(q, request_count):
+            # Takes started threads off the queue in order, waits for each
+            # to end and records the sum of its ``result``
+            _is_alive = thread_is_alive
+            while len(finished) < request_count:
+                thread = q.get(True)
+                while _is_alive(thread):
+                    thread.join(timeout=0.001)
+                in_flight["threads"] -= 1
+                finished.append(sum(thread.result))
+                callback(thread.i, request_count, end=True)
+
+        q = Queue(max_threads)
+        prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
+        cons_thread = threading.Thread(target=consumer, args=(q, request_count))
+        start = timeit.default_timer()
+        prod_thread.start()
+        cons_thread.start()
+        # Join in short slices rather than with one blocking join
+        _is_alive = thread_is_alive
+        while _is_alive(prod_thread):
+            prod_thread.join(timeout=0.001)
+        while _is_alive(cons_thread):
+            cons_thread.join(timeout=0.001)
+
+        stop = timeit.default_timer()
+        self.results.bytes_received = sum(finished)
+        self.results.download = (self.results.bytes_received / (stop - start)) * 8.0
+        # A measured download above 100000 raises the upload thread count
+        # to 8 for a later upload test
+        if self.results.download > 100000:
+            self.config["threads"]["upload"] = 8
+        return self.results.download
+
+    def upload(self, callback=do_nothing, pre_allocate=True, threads=None):
+        """Test upload speed against speedtest.net
+
+        A ``threads`` value of ``None`` will fall back to those dictated
+        by the speedtest.net configuration
+
+        ``callback`` is invoked as ``callback(i, request_count, start=True)``
+        when a request thread is started and as
+        ``callback(i, request_count, end=True)`` when it has finished.
+        With ``pre_allocate`` set, ``pre_allocate()`` is called on every
+        upload payload before the test starts.
+
+        Stores the byte total and the speed in ``self.results`` and returns
+        the speed (bytes * 8 / elapsed seconds).
+        """
+
+        # One entry per configured upload size, repeated "counts" times
+        sizes = []
+
+        for size in self.config["sizes"]["upload"]:
+            for _ in range(0, self.config["counts"]["upload"]):
+                sizes.append(size)
+
+        # request_count = len(sizes)
+        request_count = self.config["upload_max"]
+
+        requests = []
+        for i, size in enumerate(sizes):
+            # We set ``0`` for ``start`` and handle setting the actual
+            # ``start`` in ``HTTPUploader`` to get better measurements
+            data = HTTPUploaderData(size, 0, self.config["length"]["upload"], shutdown_event=self._shutdown_event)
+            if pre_allocate:
+                data.pre_allocate()
+
+            headers = {"Content-length": size}
+            requests.append((build_request(self.best["url"], data, secure=self._secure, headers=headers), size))
+
+        max_threads = threads or self.config["threads"]["upload"]
+        # A dict so the nested functions below can update the counter
+        in_flight = {"threads": 0}
+
+        def producer(q, requests, request_count):
+            # Starts one uploader per (request, size) pair, limited to the
+            # first request_count pairs, waiting while max_threads are
+            # already in flight. ``start`` is read from the enclosing scope
+            # and is assigned before this thread is started.
+            for i, request in enumerate(requests[:request_count]):
+                thread = HTTPUploader(
+                    i,
+                    request[0],
+                    start,
+                    request[1],
+                    self.config["length"]["upload"],
+                    opener=self._opener,
+                    shutdown_event=self._shutdown_event,
+                )
+                while in_flight["threads"] >= max_threads:
+                    timeit.time.sleep(0.001)
+                thread.start()
+                q.put(thread, True)
+                in_flight["threads"] += 1
+                callback(i, request_count, start=True)
+
+        finished = []
+
+        def consumer(q, request_count):
+            # Takes started threads off the queue in order, waits for each
+            # to end and records its ``result``.
+            # NOTE(review): this loops until request_count results exist, so
+            # it relies on the producer starting that many threads -- confirm
+            # that "upload_max" never exceeds len(requests)
+            _is_alive = thread_is_alive
+            while len(finished) < request_count:
+                thread = q.get(True)
+                while _is_alive(thread):
+                    thread.join(timeout=0.001)
+                in_flight["threads"] -= 1
+                finished.append(thread.result)
+                callback(thread.i, request_count, end=True)
+
+        q = Queue(threads or self.config["threads"]["upload"])
+        prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
+        cons_thread = threading.Thread(target=consumer, args=(q, request_count))
+        start = timeit.default_timer()
+        prod_thread.start()
+        cons_thread.start()
+        # Join in short slices rather than with one blocking join
+        _is_alive = thread_is_alive
+        while _is_alive(prod_thread):
+            prod_thread.join(timeout=0.1)
+        while _is_alive(cons_thread):
+            cons_thread.join(timeout=0.1)
+
+        stop = timeit.default_timer()
+        self.results.bytes_sent = sum(finished)
+        self.results.upload = (self.results.bytes_sent / (stop - start)) * 8.0
+        return self.results.upload
+
+
+def ctrl_c(shutdown_event):
+    """Build a signal handler that sets ``shutdown_event`` for our threaded
+    operations when the Ctrl-C key sequence is caught, then exits
+    """
+
+    def handle_interrupt(signum, frame):
+        # Flag the event first, then report and leave with status 0
+        shutdown_event.set()
+        printer("\nCancelling...", error=True)
+        sys.exit(0)
+
+    return handle_interrupt
+
+
+def version():
+    """Print the speedtest-cli and Python versions, then exit with status 0"""
+
+    python_version = sys.version.replace("\n", "")
+    banner = ("speedtest-cli %s" % __version__, "Python %s" % python_version)
+    for line in banner:
+        printer(line)
+    sys.exit(0)
+
+
+def csv_header(delimiter=","):
+    """Print the CSV header row, then exit with status 0"""
+
+    header_row = SpeedtestResults.csv_header(delimiter=delimiter)
+    printer(header_row)
+    sys.exit(0)
+
+
+def parse_args():
+    """Function to handle building and parsing of command line arguments
+
+    Returns the parsed options object. When the parser returns a tuple,
+    only its first element (the options) is returned.
+    """
+    description = (
+        "Command line interface for testing internet bandwidth using "
+        "speedtest.net.\n"
+        "------------------------------------------------------------"
+        "--------------\n"
+        "https://github.com/sivel/speedtest-cli"
+    )
+
+    parser = ArgParser(description=description)
+    # Give optparse.OptionParser an `add_argument` method for
+    # compatibility with argparse.ArgumentParser
+    try:
+        parser.add_argument = parser.add_option
+    except AttributeError:
+        pass
+    # --- Test selection ---
+    parser.add_argument(
+        "--no-download",
+        dest="download",
+        default=True,
+        action="store_const",
+        const=False,
+        help="Do not perform download test",
+    )
+    parser.add_argument(
+        "--no-upload", dest="upload", default=True, action="store_const", const=False, help="Do not perform upload test"
+    )
+    parser.add_argument(
+        "--single",
+        default=False,
+        action="store_true",
+        help="Only use a single connection instead of " "multiple. This simulates a typical file " "transfer.",
+    )
+    # --- Output format ---
+    # ``units`` is a (label, divisor) pair applied to the bit/s results
+    parser.add_argument(
+        "--bytes",
+        dest="units",
+        action="store_const",
+        const=("byte", 8),
+        default=("bit", 1),
+        help="Display values in bytes instead of bits. Does "
+        "not affect the image generated by --share, nor "
+        "output from --json or --csv",
+    )
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        help="Generate and provide a URL to the speedtest.net " "share results image, not displayed with --csv",
+    )
+    parser.add_argument(
+        "--simple", action="store_true", default=False, help="Suppress verbose output, only show basic " "information"
+    )
+    parser.add_argument(
+        "--csv",
+        action="store_true",
+        default=False,
+        help="Suppress verbose output, only show basic "
+        "information in CSV format. Speeds listed in "
+        "bit/s and not affected by --bytes",
+    )
+    parser.add_argument(
+        "--csv-delimiter",
+        default=",",
+        type=PARSER_TYPE_STR,
+        help="Single character delimiter to use in CSV " 'output. Default ","',
+    )
+    parser.add_argument("--csv-header", action="store_true", default=False, help="Print CSV headers")
+    parser.add_argument(
+        "--json",
+        action="store_true",
+        default=False,
+        help="Suppress verbose output, only show basic "
+        "information in JSON format. Speeds listed in "
+        "bit/s and not affected by --bytes",
+    )
+    # --- Server selection ---
+    parser.add_argument(
+        "--list", action="store_true", help="Display a list of speedtest.net servers " "sorted by distance"
+    )
+    parser.add_argument(
+        "--server",
+        type=PARSER_TYPE_INT,
+        action="append",
+        help="Specify a server ID to test against. Can be " "supplied multiple times",
+    )
+    parser.add_argument(
+        "--exclude",
+        type=PARSER_TYPE_INT,
+        action="append",
+        help="Exclude a server from selection. Can be " "supplied multiple times",
+    )
+    parser.add_argument("--mini", help="URL of the Speedtest Mini server")
+    # --- Connection settings ---
+    parser.add_argument("--source", help="Source IP address to bind to")
+    parser.add_argument("--timeout", default=10, type=PARSER_TYPE_FLOAT, help="HTTP timeout in seconds. Default 10")
+    parser.add_argument(
+        "--secure",
+        action="store_true",
+        help="Use HTTPS instead of HTTP when communicating " "with speedtest.net operated servers",
+    )
+    parser.add_argument(
+        "--no-pre-allocate",
+        dest="pre_allocate",
+        action="store_const",
+        default=True,
+        const=False,
+        help="Do not pre allocate upload data. Pre allocation "
+        "is enabled by default to improve upload "
+        "performance. To support systems with "
+        "insufficient memory, use this option to avoid a "
+        "MemoryError",
+    )
+    # --- Miscellaneous ---
+    parser.add_argument("--version", action="store_true", help="Show the version number and exit")
+    # --debug is hidden from the help output via ARG_SUPPRESS
+    parser.add_argument("--debug", action="store_true", help=ARG_SUPPRESS, default=ARG_SUPPRESS)
+
+    options = parser.parse_args()
+    # A tuple result carries the options first; keep only those
+    if isinstance(options, tuple):
+        args = options[0]
+    else:
+        args = options
+    return args
+
+
+def validate_optional_args(args):
+    """Reject options whose supporting module is unavailable.
+
+    Some arguments depend on a module that may not be part of the Python
+    standard library. When such an argument is set on ``args`` and its
+    module is ``None``, exit with an error naming the missing module.
+    """
+    optional_args = {
+        "json": ("json/simplejson python module", json),
+        "secure": ("SSL support", HTTPSConnection),
+    }
+
+    for arg, (description, module) in optional_args.items():
+        requested = getattr(args, arg, False)
+        if requested and module is None:
+            raise SystemExit("%s is not installed. --%s is " "unavailable" % (description, arg))
+
+
+def printer(string, quiet=False, debug=False, error=False, **kwargs):
+    """Print ``string`` unless suppressed.
+
+    Debug messages are dropped unless ``DEBUG`` is set, and are prefixed
+    with ``DEBUG:`` (with colour codes when stdout is a terminal). ``error``
+    sends the output to stderr, ``quiet`` suppresses it. Remaining keyword
+    arguments are passed on to ``print_``.
+    """
+
+    if debug and not DEBUG:
+        return
+
+    out = string
+    if debug:
+        template = "\033[1;30mDEBUG: %s\033[0m" if sys.stdout.isatty() else "DEBUG: %s"
+        out = template % string
+
+    if error:
+        kwargs["file"] = sys.stderr
+
+    if quiet:
+        return
+    print_(out, **kwargs)
+
+
+def shell():
+    """Run the full speedtest.net test
+
+    Parses the command line, retrieves the configuration and server list,
+    selects a server, runs the requested download/upload tests and prints
+    the results in the requested format.
+
+    Raises ``SpeedtestCLIError`` for invalid option combinations and for
+    configuration or server retrieval failures. Exits via ``sys.exit(0)``
+    for ``--version``, ``--csv-header`` and ``--list``.
+    """
+
+    global DEBUG
+    shutdown_event = threading.Event()
+
+    signal.signal(signal.SIGINT, ctrl_c(shutdown_event))
+
+    args = parse_args()
+
+    # Print the version and exit
+    if args.version:
+        version()
+
+    if not args.download and not args.upload:
+        raise SpeedtestCLIError("Cannot supply both --no-download and " "--no-upload")
+
+    if len(args.csv_delimiter) != 1:
+        raise SpeedtestCLIError("--csv-delimiter must be a single character")
+
+    if args.csv_header:
+        csv_header(args.csv_delimiter)
+
+    validate_optional_args(args)
+
+    # --debug defaults to the suppress marker, which must count as "off"
+    debug = getattr(args, "debug", False)
+    if debug == "SUPPRESSHELP":
+        debug = False
+    if debug:
+        DEBUG = True
+
+    quiet = bool(args.simple or args.csv or args.json)
+    machine_format = bool(args.csv or args.json)
+
+    # Don't set a callback if we are running quietly
+    if quiet or debug:
+        callback = do_nothing
+    else:
+        callback = print_dots(shutdown_event)
+
+    printer("Retrieving speedtest.net configuration...", quiet)
+    # NOTE(review): shutdown_event is not passed to Speedtest here, so the
+    # test threads do not observe it -- confirm whether that is intended
+    try:
+        speedtest = Speedtest(source_address=args.source, timeout=args.timeout, secure=args.secure)
+    except (ConfigRetrievalError,) + HTTP_ERRORS:
+        printer("Cannot retrieve speedtest configuration", error=True)
+        raise SpeedtestCLIError(get_exception())
+
+    if args.list:
+        try:
+            speedtest.get_servers()
+        except (ServersRetrievalError,) + HTTP_ERRORS:
+            printer("Cannot retrieve speedtest server list", error=True)
+            raise SpeedtestCLIError(get_exception())
+
+        for _, servers in sorted(speedtest.servers.items()):
+            for server in servers:
+                line = "%(id)5s) %(sponsor)s (%(name)s, %(country)s) " "[%(d)0.2f km]" % server
+                try:
+                    printer(line)
+                except IOError:
+                    # A closed pipe (e.g. piping into `head`) is not an error
+                    e = get_exception()
+                    if e.errno != errno.EPIPE:
+                        raise
+        sys.exit(0)
+
+    printer("Testing from %(isp)s (%(ip)s)..." % speedtest.config["client"], quiet)
+
+    if not args.mini:
+        printer("Retrieving speedtest.net server list...", quiet)
+        # args.server is None when --server was not given (for example when
+        # only --exclude is used), so never iterate it directly
+        requested_ids = ", ".join("%s" % s for s in (args.server or []))
+        try:
+            speedtest.get_servers(servers=args.server, exclude=args.exclude)
+        except NoMatchedServers:
+            raise SpeedtestCLIError("No matched servers: %s" % requested_ids)
+        except (ServersRetrievalError,) + HTTP_ERRORS:
+            printer("Cannot retrieve speedtest server list", error=True)
+            raise SpeedtestCLIError(get_exception())
+        except InvalidServerIDType:
+            raise SpeedtestCLIError("%s is an invalid server type, must " "be an int" % requested_ids)
+
+        if args.server and len(args.server) == 1:
+            printer("Retrieving information for the selected server...", quiet)
+        else:
+            printer("Selecting best server based on ping...", quiet)
+        speedtest.get_best_server()
+    else:
+        speedtest.get_best_server(speedtest.set_mini_server(args.mini))
+
+    results = speedtest.results
+
+    printer("Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: " "%(latency)s ms" % results.server, quiet)
+
+    if args.download:
+        printer("Testing download speed", quiet, end=("", "\n")[bool(debug)])
+        speedtest.download(callback=callback, threads=(None, 1)[args.single])
+        printer("Download: %0.2f M%s/s" % ((results.download / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
+    else:
+        printer("Skipping download test", quiet)
+
+    if args.upload:
+        printer("Testing upload speed", quiet, end=("", "\n")[bool(debug)])
+        speedtest.upload(callback=callback, pre_allocate=args.pre_allocate, threads=(None, 1)[args.single])
+        printer("Upload: %0.2f M%s/s" % ((results.upload / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
+    else:
+        printer("Skipping upload test", quiet)
+
+    printer("Results:\n%r" % results.dict(), debug=True)
+
+    # Submit the results first so that --csv/--json output carries the link
+    if not args.simple and args.share:
+        results.share()
+
+    if args.simple:
+        printer(
+            "Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s"
+            % (
+                results.ping,
+                (results.download / 1000.0 / 1000.0) / args.units[1],
+                args.units[0],
+                (results.upload / 1000.0 / 1000.0) / args.units[1],
+                args.units[0],
+            )
+        )
+    elif args.csv:
+        printer(results.csv(delimiter=args.csv_delimiter))
+    elif args.json:
+        printer(results.json())
+
+    if args.share and not machine_format:
+        printer("Share results: %s" % results.share())
+
+
+def main():
+    """Run ``shell`` and turn failures into an ``ERROR: ...`` exit."""
+    try:
+        shell()
+    except KeyboardInterrupt:
+        printer("\nCancelling...", error=True)
+    except (SpeedtestException, SystemExit):
+        e = get_exception()
+        exit_code = getattr(e, "code", 1)
+        # Ignore a successful exit, or argparse exit
+        if exit_code in (0, 2):
+            return
+        # Fall back to the repr when the exception has no message
+        msg = ("%s" % e) or ("%r" % e)
+        raise SystemExit("ERROR: %s" % msg)
+
+
+# Run the command line interface only when executed as a script
+if __name__ == "__main__":
+    main()
diff --git a/petals/examples/prompt-tuning-personachat.ipynb b/petals/examples/prompt-tuning-personachat.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..77312fda01d62cf5b0e62711816dd5744a88d523
--- /dev/null
+++ b/petals/examples/prompt-tuning-personachat.ipynb
@@ -0,0 +1,339 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "a07e0f5e",
+ "metadata": {},
+ "source": [
+ "\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "# Distributed Bloom for Text Generation using Prompt Tuning\n",
+ "\n",
+ "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
+ "\n",
+ "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
+ "\n",
+ "To open this notebook in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a3f8526f",
+ "metadata": {},
+ "source": [
+ "First, we have to prepare all dependencies."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "73bbc648",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This block is only needed for Colab users. It will change nothing if you are running this notebook locally.\n",
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "\n",
+ "IN_COLAB = 'google.colab' in sys.modules\n",
+ "\n",
+ "if IN_COLAB:\n",
+ " subprocess.run(['git', 'clone', 'https://github.com/bigscience-workshop/petals'])\n",
+ " subprocess.run(['pip', 'install', '-r', 'petals/requirements.txt'])\n",
+ " subprocess.run(['pip', 'install', 'datasets', 'lib64'])\n",
+ "\n",
+ " try:\n",
+ " subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
+ " except subprocess.CalledProcessError as e:\n",
+ " subprocess.run(['rm', '-r', '/usr/local/cuda/lib64'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b4ab6ca7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "sys.path.insert(0, \"..\") # for colab change to sys.path.insert(0, './petals/')\n",
+ " \n",
+ "import torch\n",
+ "import transformers\n",
+ "import wandb\n",
+ "from datasets import load_dataset\n",
+ "from tqdm import tqdm\n",
+ "from torch.optim import AdamW\n",
+ "from torch.utils.data import DataLoader\n",
+ "from transformers import get_scheduler\n",
+ "\n",
+ "# Import a Petals model\n",
+ "from src.client.remote_model import DistributedBloomForCausalLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1bf07b5d",
+ "metadata": {},
+ "source": [
+ "Let's set some hyperparameters for training:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f04ba4d2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = ... # select model you like\n",
+ "INITIAL_PEERS = [...] # add your peers' addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
+ "NUM_PREFIX_TOKENS = 16\n",
+ "DEVICE = 'cpu'\n",
+ "BATCH_SIZE = 4\n",
+ "LR = 1e-2\n",
+ "WEIGHT_DECAY = 0.0\n",
+ "NUM_SAMPLES = 1000\n",
+ "SEED = 42\n",
+ "MODEL_MAX_LENGTH = 256\n",
+ "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d38316bd",
+ "metadata": {},
+ "source": [
+ "Prepare tokenizer and distributed model, connect it to servers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "03c6e53e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
+ "tokenizer.padding_side = 'right'\n",
+ "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
+ "model = DistributedBloomForCausalLM.from_pretrained(\n",
+ " MODEL_NAME, \n",
+ " initial_peers=INITIAL_PEERS, \n",
+ " pre_seq_len=NUM_PREFIX_TOKENS, \n",
+ " tuning_mode=TUNING_MODE\n",
+ ").to(DEVICE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "042e3786",
+ "metadata": {},
+ "source": [
+ "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9c44d516",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"bavard/personachat_truecased\")\n",
+ "\n",
+ "\n",
+ "def chunking(examples):\n",
+ " inputs = [\n",
+ " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
+ " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
+ " for candidate in candidates\n",
+ " ]\n",
+ " return {\"chunks\": inputs}\n",
+ "\n",
+ "\n",
+ "def tokenize(examples):\n",
+ " outputs = {\n",
+ " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
+ " }\n",
+ " outputs[\"labels\"] = outputs[\"input_ids\"]\n",
+ " return outputs\n",
+ "\n",
+ "\n",
+ "tokenized_datasets = (\n",
+ " dataset\n",
+ " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
+ " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
+ ")\n",
+ "\n",
+ "\n",
+ "tokenized_datasets.set_format(\"torch\")\n",
+ "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
+ "train_dataloader = DataLoader(\n",
+ " train_dataset.select(list(range(NUM_SAMPLES))),\n",
+ " shuffle=True,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " drop_last=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ef4323fd",
+ "metadata": {},
+ "source": [
+ "Before setting up optimizers, check the model parameters that will be trained."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9cc0ba34",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for n, p in model.named_parameters():\n",
+ " if p.requires_grad:\n",
+ " print(n, p.requires_grad, p.device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "59cffce7",
+ "metadata": {},
+ "source": [
+ "The optimizer will only work on **prompts**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ef9bf344",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
+ "\n",
+ "lr_scheduler = get_scheduler(\n",
+ " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "423c56d5",
+ "metadata": {},
+ "source": [
+ "Let's initialize wandb for logging and start the training loop!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d9e46807",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.init(\n",
+ " project=\"bloom-personachat\",\n",
+ " config={\n",
+ " \"num_samples\": NUM_SAMPLES,\n",
+ " \"batch_size\": BATCH_SIZE,\n",
+ " \"learning_rate\": LR,\n",
+ " \"weight_decay\": WEIGHT_DECAY,\n",
+ " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
+ " \"model_name\": MODEL_NAME,\n",
+ " \"seed\": SEED,\n",
+ " }\n",
+ ")\n",
+ "\n",
+ "for batch in tqdm(train_dataloader):\n",
+ " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
+ "\n",
+ " model.train()\n",
+ " outputs = model(**batch)\n",
+ " loss = outputs.loss\n",
+ " loss.backward()\n",
+ "\n",
+ " optimizer.step()\n",
+ " lr_scheduler.step()\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " wandb.log({\"Train Loss\": loss})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0f36cb80",
+ "metadata": {},
+ "source": [
+ "Try to talk with the trained model! Submit an empty input to stop the execution.\n",
+ "\n",
+ "\n",
+ "__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "720181b7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MAX_TOKENS = 16\n",
+ "TOP_K = 100\n",
+ "TEMPERATURE = 0.6\n",
+ "dialog = \"\"\n",
+ "\n",
+ "while True:\n",
+ " user_phrase = input()\n",
+ " if len(user_phrase) == 0:\n",
+ " break\n",
+ " dialog += f\"{user_phrase}\\n-----\\n\"\n",
+ " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
+ " outputs = model.generate(\n",
+ " inputs,\n",
+ " temperature=TEMPERATURE,\n",
+ " do_sample=True,\n",
+ " top_k=TOP_K,\n",
+ " eos_token_id=tokenizer.eos_token_id,\n",
+ " max_new_tokens=MAX_TOKENS,\n",
+ " )\n",
+ " bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
+ " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
+ " print(bloom_answer)\n",
+ " dialog += f\"{bloom_answer}\\n-----\\n\""
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.10 64-bit",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.9"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/petals/pyproject.toml b/petals/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..a5b7e30432d5bdf7b0393ed035231cdad56418d8
--- /dev/null
+++ b/petals/pyproject.toml
@@ -0,0 +1,10 @@
+[tool.black]
+line-length = 120
+required-version = "22.3.0"
+
+[tool.isort]
+profile = "black"
+line_length = 120
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["tests", "cli"]
\ No newline at end of file
diff --git a/petals/requirements-dev.txt b/petals/requirements-dev.txt
new file mode 100644
index 0000000000000000000000000000000000000000..637434d0561511c49349b980c383ff56dbd66aa2
--- /dev/null
+++ b/petals/requirements-dev.txt
@@ -0,0 +1,6 @@
+pytest==6.2.5 # see https://github.com/pytest-dev/pytest/issues/9621
+pytest-forked
+pytest-asyncio==0.16.0
+black==22.3.0
+isort==5.10.1
+psutil
\ No newline at end of file
diff --git a/petals/requirements.txt b/petals/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7ecd566602a1366e2154d0cd67a94b025afa24f2
--- /dev/null
+++ b/petals/requirements.txt
@@ -0,0 +1,8 @@
+torch>=1.12
+bitsandbytes==0.34.0
+accelerate==0.10.0
+huggingface-hub==0.7.0
+transformers==4.21.3
+protobuf>=3.12.2,<4.0.0
+git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
+humanfriendly
diff --git a/petals/src/__init__.py b/petals/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83c699c77df2ccd35a6ed5612e457baa0b2fc31
--- /dev/null
+++ b/petals/src/__init__.py
@@ -0,0 +1,6 @@
+from src.bloom import *
+from src.client import *
+from src.dht_utils import declare_active_modules, get_remote_module
+
+project_name = "bloomd"
+__version__ = "0.2"
diff --git a/petals/src/bloom/__init__.py b/petals/src/bloom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42bf511dfde085ee1502891743b740a86a6ed5f
--- /dev/null
+++ b/petals/src/bloom/__init__.py
@@ -0,0 +1,2 @@
+from src.bloom.block import BloomBlock
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
diff --git a/petals/src/bloom/block.py b/petals/src/bloom/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..1898d37ed0b6ff6278252bc5883afdd20267149c
--- /dev/null
+++ b/petals/src/bloom/block.py
@@ -0,0 +1,255 @@
+"""
+Bloom intermediate layer
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.quantized.dynamic.modules.linear
+
+from src.bloom.ops import (
+ BloomGelu,
+ BloomScaledSoftmax,
+ attention_mask_func,
+ build_alibi_tensor,
+ dropout_add,
+ pre_process_alibi_for_pad,
+ split_tensor_along_last_dim,
+)
+
+
+class BloomAttention(nn.Module):
+    """Self-attention layer of a BLOOM block with an ALiBi bias and an optional key/value cache.
+
+    The layer computes fused query/key/value projections, adds the ALiBi tensor to the scaled query-key
+    product, applies the masked scaled softmax and attention dropout, mixes the values, applies the output
+    projection, and finally combines the result with the residual through ``dropout_add``.
+    """
+
+    def __init__(self, config, layer_number=None):
+        """Build the projections and the scaled softmax from ``config``.
+
+        Args:
+            config: model configuration; this layer reads ``hidden_size``, ``n_head``,
+                ``attention_softmax_in_fp32``, ``masked_softmax_fusion``, ``hidden_dropout`` and
+                ``attention_dropout``.
+            layer_number: index used for layer-wise attention scaling. ``None`` and values below 1 are
+                treated as 1.
+
+        Raises:
+            ValueError: if ``hidden_size`` is not divisible by ``n_head``.
+        """
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.n_head
+        self.head_dim = self.hidden_size // self.num_heads
+        self.split_size = self.hidden_size
+        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+        self.masked_softmax_fusion = config.masked_softmax_fusion
+        self.hidden_dropout = config.hidden_dropout
+
+        if self.head_dim * self.num_heads != self.hidden_size:
+            raise ValueError(
+                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        # Layer-wise attention scaling.
+        # `layer_number` defaults to None, and max(1, None) raises TypeError, so map None to 1 first.
+        if layer_number is None:
+            layer_number = 1
+        self.layer_number = max(1, layer_number)
+        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
+
+        # Scaled Softmax
+        self.scale_mask_softmax = BloomScaledSoftmax(
+            self.masked_softmax_fusion,
+            attention_mask_func,
+            self.attention_softmax_in_fp32,
+            self.layer_number,
+        )
+
+        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
+        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        layer_past=None,
+        attention_mask=None,
+        alibi=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        """Apply self-attention to ``hidden_states`` and combine the result with ``residual``.
+
+        Args:
+            hidden_states: input of shape [batch_size, seq_length, hidden_size].
+            residual: tensor combined with the attention output via ``dropout_add``.
+            layer_past: optional ``(past_key, past_value)`` pair; both are prepended to the new keys and
+                values along dim 1.
+            attention_mask: optional mask; when given, the ALiBi tensor is pre-processed for padding,
+                otherwise it is repeated over the batch.
+            alibi: optional precomputed ALiBi tensor; built with ``build_alibi_tensor`` when omitted.
+            head_mask: optional multiplier applied to the attention probabilities.
+            use_cache: when exactly ``True``, the concatenated ``(key, value)`` pair is returned.
+            output_attentions: when truthy, the attention probabilities are appended to the result.
+
+        Returns:
+            ``(output, present)`` or ``(output, present, attention_probs)``; ``present`` is ``None``
+            unless ``use_cache`` is ``True``.
+        """
+        if alibi is None:
+            # The ALiBi tensor must span the cached positions as well as the new ones.
+            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
+            alibi = build_alibi_tensor(
+                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
+            )
+
+        # hidden_states: [batch_size, seq_length, hidden_size]
+        # apply preprocessing if the input is padded
+        if attention_mask is not None:
+            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
+        # otherwise repeat alibi tensor with the batch size
+        else:
+            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
+
+        mixed_x_layer = self.query_key_value(hidden_states)
+
+        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
+            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+
+        if use_cache is True:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
+        # [batch_size, num_heads, q_length, k_length]
+        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
+
+        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
+        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
+
+        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
+        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
+
+        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
+        # baddbmm computes beta * alibi + alpha * (query @ key^T) in one call.
+        beta = 1.0 / self.layer_number
+
+        matmul_result = torch.baddbmm(
+            alibi,
+            query_layer.transpose(1, 0),
+            key_layer.transpose(1, 0).transpose(1, 2),
+            beta=beta,
+            alpha=(1.0 / self.norm_factor),
+        )
+
+        # change view to [batch_size, num_heads, q_length, k_length]
+        attention_scores = matmul_result.view(*output_size)
+
+        # attention scores and attention mask [b, np, sq, sk]
+        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
+        attention_probs = self.attention_dropout(attention_probs)
+
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        # context layer shape: [batch_size, num_heads, q_length, head_dim]
+        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+
+        # change view [k_length, batch_size x num_heads, head_dim]
+        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
+
+        # change view [batch_size x num_heads, q_length, k_length]
+        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+
+        # matmul: [batch_size * num_heads, q_length, head_dim]
+        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+
+        # change view [batch_size, num_heads, q_length, head_dim]
+        context_layer = context_layer.view(*output_size)
+
+        # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # Output. [q_length, batch_size, hidden_size]
+
+        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+        output_tensor = self.dense(context_layer)
+        # Swap the first two dims back so the output is batch-first again.
+        output = output_tensor.transpose(1, 0)
+
+        output = dropout_add(output, residual, self.hidden_dropout, self.training)
+
+        outputs = (output, present)
+        if output_attentions:
+            outputs += (attention_probs,)
+
+        return outputs
+
+
+class BloomMLP(nn.Module):
+    """Feed-forward part of a BLOOM block: expand to 4x hidden size, activate, project back, add the residual."""
+
+    def __init__(self, config):
+        super().__init__()
+        width = config.hidden_size
+        self.hidden_size = width
+        self.dense_h_to_4h = nn.Linear(width, 4 * width)
+        self.dense_4h_to_h = nn.Linear(4 * width, width)
+        self.hidden_dropout = config.hidden_dropout
+        self.gelu_impl = BloomGelu()
+
+    def forward(self, hidden_states, residual):
+        """Return dropout_add(MLP(hidden_states), residual) using this module's dropout rate and training flag."""
+        expanded = self.dense_h_to_4h(hidden_states)
+        projected = self.dense_4h_to_h(self.gelu_impl(expanded))
+        return dropout_add(projected, residual, self.hidden_dropout, self.training)
+
+
+class BloomBlock(nn.Module):
+    """One BLOOM transformer layer: layer norm, self-attention, layer norm, MLP."""
+
+    def __init__(self, config, layer_number=None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.n_head = config.n_head
+        self.hidden_dropout = config.hidden_dropout
+        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+
+        # Sub-modules are created in the same order as before: norm, attention, norm, MLP.
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
+        self.self_attention = BloomAttention(config, layer_number=layer_number)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = BloomMLP(config)
+
+    def forward(
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+        alibi=None,
+    ):
+        """Run the layer on hidden_states of shape [batch_size, seq_length, hidden_size].
+
+        Returns a tuple that starts with the new hidden states. The cache entry follows only when
+        use_cache is truthy, and whatever else the attention layer returned comes after that.
+        """
+        # Layer norm at the beginning of the transformer layer.
+        normed_input = self.input_layernorm(hidden_states)
+
+        # The residual is taken either after or before the layer norm, depending on the config flag.
+        residual = normed_input if self.apply_residual_connection_post_layernorm else hidden_states
+
+        # Self attention.
+        attn_outputs = self.self_attention(
+            normed_input,
+            residual,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        attention_output = attn_outputs[0]
+        extras = attn_outputs[1:]
+
+        normed_attention = self.post_attention_layernorm(attention_output)
+
+        # Same residual choice for the MLP.
+        residual = normed_attention if self.apply_residual_connection_post_layernorm else attention_output
+
+        # MLP.
+        output = self.mlp(normed_attention, residual)
+
+        if not use_cache:
+            # Drop the cache slot that the attention layer always puts first.
+            extras = extras[1:]
+
+        return (output,) + extras  # hidden_states, present, attentions
diff --git a/petals/src/bloom/from_pretrained.py b/petals/src/bloom/from_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8bd398791344f9d98916a6c1f99ded162fb2735
--- /dev/null
+++ b/petals/src/bloom/from_pretrained.py
@@ -0,0 +1,86 @@
+"""
+Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+If necessary, one can rewrite this to implement a different behavior, such as:
+ - loading files from a local data source (e.g. S3)
+ - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
+ - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
+
+"""
+from __future__ import annotations
+
+from typing import Optional, OrderedDict, Union
+
+import torch
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from transformers.modeling_utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_path, hf_bucket_url
+
+from src.bloom import BloomBlock, BloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+# Revision requested by _load_state_dict when no block index is given.
+CLIENT_BRANCH = "main"
+# Prefix of the per-block revisions: block i is requested from revision BLOCK_BRANCH_PREFIX + str(i).
+BLOCK_BRANCH_PREFIX = "block_"
+# The four settings below are passed unchanged to cached_path() by _load_state_dict.
+USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
+FORCE_DOWNLOAD = False
+RESUME_DOWNLOAD = False
+LOCAL_FILES_ONLY = False
+
+
+def load_pretrained_block(
+    converted_model_name_or_path: str,
+    block_index: int,
+    config: Optional[BloomConfig] = None,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+) -> BloomBlock:
+    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it.
+
+    Args:
+        converted_model_name_or_path: name or path of the converted model to fetch from.
+        block_index: index of the block to load; also passed to BloomBlock as ``layer_number``.
+        config: model config; fetched with ``BloomConfig.from_pretrained`` when omitted.
+        torch_dtype: ``"auto"`` keeps the dtype of each tensor in the checkpoint; any other value must be
+            one of the values of ``DTYPE_MAP`` and is applied to the whole block.
+        use_auth_token: forwarded to the config and weight downloads.
+        cache_dir: forwarded to the weight download.
+
+    Returns:
+        The block with the checkpoint weights loaded.
+
+    Raises:
+        AssertionError: if ``torch_dtype`` is neither ``"auto"`` nor a value of ``DTYPE_MAP``.
+        RuntimeError: raised by ``load_state_dict(strict=True)`` when the checkpoint keys do not match.
+    """
+    if config is None:
+        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+    block = BloomBlock(config, layer_number=block_index)
+    state_dict = _load_state_dict(
+        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
+
+    # Set the parameter dtypes first and copy the weights once afterwards. Loading the state dict
+    # before the cast as well would only copy every tensor a second time.
+    if torch_dtype == "auto":
+        with torch.no_grad():
+            for name, param in block.named_parameters():
+                # Keys missing from the checkpoint are reported by the strict load below.
+                if name in state_dict:
+                    param.data = param.data.to(state_dict[name].dtype)
+    else:
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        block = block.to(dtype=torch_dtype)
+
+    report = block.load_state_dict(state_dict, strict=True)
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    return block
+
+
+def _load_state_dict(
+    pretrained_model_name_or_path: str,
+    block_index: Optional[int] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+) -> OrderedDict[str, torch.Tensor]:
+    """Fetch the weights file of one revision (via cache or download) and return it loaded on the CPU.
+
+    The revision is ``BLOCK_BRANCH_PREFIX + str(block_index)`` when a block index is given and
+    ``CLIENT_BRANCH`` otherwise.
+    """
+    if block_index is None:
+        revision = CLIENT_BRANCH
+    else:
+        revision = BLOCK_BRANCH_PREFIX + str(block_index)
+
+    weights_url = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
+
+    # Load from URL or cache if already cached
+    local_weights_file = cached_path(
+        weights_url,
+        cache_dir=cache_dir,
+        force_download=FORCE_DOWNLOAD,
+        proxies=None,
+        resume_download=RESUME_DOWNLOAD,
+        local_files_only=LOCAL_FILES_ONLY,
+        use_auth_token=use_auth_token,
+        user_agent=USER_AGENT,
+    )
+    # NOTE(review): torch.load unpickles the file, so the weights source must be trusted.
+    return torch.load(local_weights_file, map_location="cpu")
+
+
+# Maps dtype names to torch dtypes ("auto" maps to itself). load_pretrained_block checks any non-"auto"
+# torch_dtype against the values of this dict.
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
diff --git a/petals/src/bloom/model.py b/petals/src/bloom/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c7d9e55523ffb267c5219adace16f7e8d6762d
--- /dev/null
+++ b/petals/src/bloom/model.py
@@ -0,0 +1,583 @@
+"""
+PyTorch BLOOM model that implements several memory-efficient modes.
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from hivemind import use_hivemind_log_handler
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from transformers.file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
+from transformers.utils import logging
+
+from src.bloom.block import BloomBlock
+
+use_hivemind_log_handler("in_root_logger")
+logger = logging.get_logger(__file__)
+
+_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
+_CONFIG_FOR_DOC = "BloomConfig"
+_TOKENIZER_FOR_DOC = "BloomTokenizer"
+
+
# Shared class-level documentation that `add_start_docstrings` attaches to the model classes in this file.
BLOOM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Shared argument documentation that `add_start_docstrings_to_model_forward` attaches to the `forward` methods.
BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Accepted for API compatibility only: this BLOOM implementation ignores `position_ids` and logs a warning
            when they are passed.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"

        self.embed_dim = config.hidden_size
        self.n_head = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Accessor for the token embedding module.
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        # Replaces the token embedding module.
        self.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Unset flags fall back to the values stored in the model config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        if position_ids is not None:
            # The warning must stay independent of the shape detection below: chaining it with `elif` skipped
            # that detection, left `input_shape` unbound and made every call with position_ids crash.
            logger.warning("position_ids are ignored in this bloom implementation")

        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_head x N x N
        # head_mask has shape n_layer x batch x n_head x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note: it supports only float32 or bfloat16 inputs
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Every block is called with alibi=None.
        # NOTE(review): presumably each block builds its own alibi tensor - confirm in BloomBlock.
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions, alibi=None)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=None,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The attention tensor follows the cache entry when one is returned.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = hidden_states.view(output_shape)

        if not return_dict:
            # Tuple output drops every entry that was not requested (i.e. is None).
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Accessor for the language modeling head.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Replaces the language modeling head.
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """
        Build the keyword arguments for one generation step.

        :param input_ids: token ids generated so far; only the last column is kept when `past` is given
        :param past: cached key/value states from previous steps, if any
        :param kwargs: may contain `use_cache` and `attention_mask`, which are forwarded as-is
        :returns: dict of arguments for `forward`; `position_ids` is always None
        """
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            # BloomModel.forward ignores position_ids (it only logs a warning about them), so deriving them
            # from the attention mask was wasted work that produced a warning on every generation step.
            # The key is kept so that the returned dict has the same shape as before.
            "position_ids": None,
            "attention_mask": kwargs.get("attention_mask", None),
        }

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # Tuple output: optional loss first, then logits, then the remaining transformer outputs.
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )
+
+
+@add_start_docstrings(
+ """
+ The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
+ embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+ In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
+ """,
+ BLOOM_START_DOCSTRING,
+)
class LMHead(nn.Module):
    def __init__(self, config, word_embeddings: nn.Embedding):
        """
        :param config: must provide `chunk_size_for_efficient_fp16_on_cpu`, the chunk size for `chunked_forward`
        :param word_embeddings: embedding module whose weight matrix is reused as the output projection
        """
        super().__init__()
        self.word_embeddings = word_embeddings
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    # The properties below mirror the attributes of `nn.Linear`. The head computes
    # F.linear(hidden_states, weight) with a weight of shape (num_embeddings, embedding_dim), so it consumes
    # `embedding_dim` features and produces `num_embeddings` logits. The two values used to be swapped.
    @property
    def in_features(self) -> int:
        return self.word_embeddings.embedding_dim

    @property
    def out_features(self) -> int:
        return self.word_embeddings.num_embeddings

    @property
    def weight(self):
        # The projection matrix is the embedding matrix itself; no copy is made.
        return self.word_embeddings.weight

    @property
    def bias(self):
        return None

    def forward(self, hidden_states):
        """Project hidden states onto the embedding matrix and return the logits."""
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings).float()
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunk_size: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        # Allocate next to the activations so the buffer always matches their device and dtype (fp32 here).
        output = hidden_states.new_zeros(*hidden_states.shape[:-1], num_embeddings)

        for i in range(0, num_embeddings, self.chunk_size):
            # Only one chunk of the embedding matrix is upcast to fp32 at a time.
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output
+
+
@add_start_docstrings(
    """
    The Bloom Model transformer with a sequence classification head on top (linear layer).
    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.
    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token scores; a single position per row is selected below.
        logits = self.score(transformer_outputs[0])

        batch_source = input_ids if input_ids is not None else inputs_embeds
        batch_size = batch_source.shape[0]

        # Pick the position whose logits represent each sequence.
        pad_token_id = self.config.pad_token_id
        if pad_token_id is None:
            if batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            sequence_lengths = -1
        elif input_ids is not None:
            # Index of the last non-padding token: number of non-padding tokens minus one.
            sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
        else:
            # Padding cannot be detected in embeddings, so fall back to the last position.
            sequence_lengths = -1
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        # Tuple output: optional loss first, then pooled logits, then the remaining transformer outputs.
        output = (pooled_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output
diff --git a/petals/src/bloom/ops.py b/petals/src/bloom/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b84c7c1a8cf8862641046e14a5b20ba3258da544
--- /dev/null
+++ b/petals/src/bloom/ops.py
@@ -0,0 +1,246 @@
+"""
+ Utility operations used in the BLOOM model
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+import math
+
+import torch
+import torch.autograd
+import torch.nn.functional as F
+from torch import nn
+
+
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Cut a tensor into equally sized pieces along its last dimension.

    Args:
        tensor: ([`torch.tensor`], *required*):
            tensor to cut
        num_partitions ([`int`], *required*):
            how many pieces to produce; must divide the size of the last dimension
        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
            if True, every piece is made contiguous in memory and the pieces are returned as a tuple

    Raises:
        ValueError: if the last dimension is not divisible by `num_partitions`
    """
    last_axis = tensor.dim() - 1
    full_width = tensor.size()[last_axis]
    if full_width % num_partitions != 0:
        raise ValueError(f"{full_width} is not divisible by {num_partitions}")

    pieces = torch.split(tensor, full_width // num_partitions, dim=last_axis)
    if not contiguous_split_chunks:
        # torch.split returns views, which are not contiguous by default.
        return pieces
    return tuple(piece.contiguous() for piece in pieces)
+
+
def attention_mask_func(attention_scores, attention_mask, causal_mask):
    """
    Mask out attention scores that are blocked by padding or by the causal mask.

    Args:
        attention_scores: scores indexed as (batch, heads, query, key); **modified in place**
        attention_mask: 2-d mask; True (bool) or 1 (numeric) marks positions that may be attended to
        causal_mask: boolean mask where True marks key positions a query may attend to

    Returns:
        a tuple of (`attention_scores` with blocked entries set to -10000.0, boolean mask of blocked entries)
    """
    # Flip the mask so that True means "this position is padding".
    if attention_mask.dtype == torch.bool:
        is_padding = ~attention_mask
    else:
        is_padding = (1 - attention_mask).bool()

    n_heads = attention_scores.size(1)
    query_length = attention_scores.size(2)
    key_length = attention_scores.size(3)
    first_query = key_length - query_length  # the queries are the last `query_length` positions

    padded_query_rows = is_padding[:, None, first_query:key_length, None]
    non_causal = ~causal_mask[:, :, first_query:key_length, :key_length]
    padded_key_columns = is_padding[:, None, None, :key_length].bool()

    # Adding boolean tensors acts as a logical "or" here.
    padded_causal_mask = (padded_query_rows + non_causal).bool()
    padded_causal_mask = padded_causal_mask + padded_key_columns

    # Make use of floats
    masked_scores = attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0)
    return masked_scores, padded_causal_mask
+
+
def build_alibi_tensor(
    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """
    Build the ALiBi bias tensor: one slope per head multiplied by the position index.

    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len: (`int`, *required*):
            max sequence length
        n_head: (`int`, *required*):
            number of heads
        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor
        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
            device of the output alibi tensor

    Returns:
        tensor shaped (n_head, 1, max_seq_len)
    """

    def _slope_base(head_count):
        # Base of the geometric sequence of slopes for a power-of-two number of heads.
        return torch.tensor(2 ** (-(2 ** -(math.log2(head_count) - 3))), device=device, dtype=torch.float32)

    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    exponents = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(_slope_base(closest_power_of_2), exponents)

    if closest_power_of_2 != n_head:
        # The remaining heads take the odd powers of the base for twice as many heads.
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        odd_exponents = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        extra_slopes = torch.pow(_slope_base(2 * closest_power_of_2), odd_exponents)
        slopes = torch.cat([slopes, extra_slopes], dim=0)

    positions = torch.arange(max_seq_len, device=device, dtype=torch.int32)
    alibi = slopes.view(-1, 1, 1) * positions.view(1, 1, -1)
    return alibi.to(dtype)
+
+
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    """
    Re-index the alibi tensor so that positions are counted without padding, and zero it on padded positions.

    Args:
        alibi: ([`torch.tensor`], *required*):
            alibi tensor to pre-process
        attention_mask: ([`torch.tensor`], *required*):
            2-d attention mask, [batch_size, seq_length]

    Returns:
        the adjusted alibi tensor, flattened to (alibi.shape[0] * batch_size, 1, seq_length)
    """
    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"

    # [batch, max_len], values correspond to element indices after removing padding
    running_count = attention_mask.cumsum(dim=1)
    unpadded_indices = torch.relu(running_count - 1)

    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
    shifted = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1)
    alibi = shifted * attention_mask.unsqueeze(0)

    leading, batch = alibi.shape[0], alibi.shape[1]
    return alibi.reshape(leading * batch, 1, -1)
+
+
def dropout_add(x, residual, prob, training):
    """
    Apply dropout to `x` and add the result to `residual`.

    Args:
        x (`torch.tensor`, *required*):
            tensor that goes through dropout
        residual (`torch.tensor`, *required*):
            residual tensor added after dropout
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            dropout is only active when this is True
    """
    return residual + F.dropout(x, p=prob, training=training)
+
+
def bloom_gelu_forward(x):
    """
    Tanh approximation of GELU. Adapted from Megatron-DeepSpeed code; written as a plain expression (no custom
    autograd function) to make the model jitable.

    Args:
        x (`torch.tensor`, *required*):
            input hidden states
    """
    # 0.79788456 is sqrt(2 / pi)
    tanh_arg = 0.79788456 * x * (1 + 0.044715 * x * x)
    return x * 0.5 * (1.0 + torch.tanh(tanh_arg))
+
+
def bloom_gelu_back(g, x):
    """
    Gradient of the tanh approximation of GELU, multiplied by the incoming gradient. For reference, the gradient
    of the exact GELU is: 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`, *required*):
            gradient output tensor
        x (`tuple` holding one `torch.tensor`, *required*):
            the forward input, still wrapped in a 1-tuple
    """
    (inp,) = x[:1]  # x is a tuple of 1 element, needs to unpack it first
    tanh_out = torch.tanh(0.79788456 * inp * (1 + 0.044715 * inp * inp))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    sech_squared = 1 - tanh_out * tanh_out
    ff = 0.5 * inp * (sech_squared * (0.79788456 + 0.1070322243 * inp * inp)) + 0.5 * (1 + tanh_out)
    return ff * g
+
+
class GeLUFunction(torch.autograd.Function):
    """Autograd function pairing `bloom_gelu_forward` with the hand-written gradient `bloom_gelu_back`."""

    @staticmethod
    def forward(ctx, input):
        # Keep the input around: the backward pass recomputes the tanh term from it.
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        # `saved_tensors` is passed on as a tuple; bloom_gelu_back unpacks it itself.
        return bloom_gelu_back(grad_output, ctx.saved_tensors)
+
+
class BloomGelu(nn.Module):
    """
    GELU activation module that picks its implementation by mode: the plain `bloom_gelu_forward` expression in
    inference mode, which keeps the model torchscriptable, and the `GeLUFunction` autograd function in training
    mode, which supplies the hand-written gradient. Partly copied from Megatron-DeepSpeed code and adapted for
    our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329

    """

    def forward(self, x):
        if not self.training:
            return bloom_gelu_forward(x)
        return GeLUFunction.apply(x)
+
+
class BloomScaledSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Args:
        scaled_masked_softmax_fusion (`bool`, *required*):
            flag to indicate user want to use softmax fusion (stored only; not read by `forward`)
        mask_func (`function`, *required*):
            mask function to be applied; called as `mask_func(scores, mask, causal_mask)` and expected to return
            the masked scores together with a boolean mask of the blocked entries
        softmax_in_fp32 (`bool`, *required*):
            if true, softmax in performed at fp32 precision.
        scale (`float`, *required*):
            scaling factor used in input tensor scaling; may be None to skip scaling
    """

    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
        super().__init__()
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        # Scaling is only allowed together with an fp32 softmax.
        if self.scale is not None and not softmax_in_fp32:
            raise ValueError("softmax should be in fp32 when scaled")

    def forward(self, input, mask, max_positions):
        """
        Scale, mask and normalize attention scores.

        :param input: attention scores; handed to `mask_func` after optional scaling
        :param mask: attention mask, or None to attend to all `max_positions` positions
        :param max_positions: side length of the square causal mask
        :returns: attention probabilities, with blocked entries set to exactly zero
        """
        original_dtype = input.dtype
        needs_downcast = original_dtype in [torch.float16, torch.bfloat16] and self.softmax_in_fp32
        softmax_dtype = torch.float32 if self.softmax_in_fp32 else original_dtype

        if self.scale is not None:
            input = input * self.scale

        if mask is None:
            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
        mask = mask.to(input.device)

        lower_triangle = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
        causal_mask = lower_triangle.view(1, 1, max_positions, max_positions).to(input.device)

        masked_scores, padded_causal_mask = self.mask_func(input, mask, causal_mask)
        # Multiplying by the inverted mask zeroes whatever probability the softmax left on blocked entries.
        probs = F.softmax(masked_scores, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

        if needs_downcast:
            probs = probs.to(dtype=original_dtype)

        return probs
diff --git a/petals/src/client/__init__.py b/petals/src/client/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9217b17aa9ffaffc4f1d7e7de68de54c7bfa212
--- /dev/null
+++ b/petals/src/client/__init__.py
@@ -0,0 +1,5 @@
+from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
+from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from src.client.sequence_manager import RemoteSequenceManager
+from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
diff --git a/petals/src/client/inference_session.py b/petals/src/client/inference_session.py
new file mode 100644
index 0000000000000000000000000000000000000000..812e9533e40bf5e517b740d1b3a9210b886611f2
--- /dev/null
+++ b/petals/src/client/inference_session.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from typing import AsyncIterator, List, Optional
+
+import torch
+from hivemind import (
+ P2P,
+ MSGPackSerializer,
+ anext,
+ deserialize_torch_tensor,
+ get_logger,
+ nested_flatten,
+ serialize_torch_tensor,
+ use_hivemind_log_handler,
+)
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
+
+# Module-level logging setup: select hivemind's "in_root_logger" handler mode, then create this file's logger.
+# NOTE(review): presumably this routes hivemind log records through the root logger - confirm in hivemind docs.
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteTransformerBlockInferenceSession:
+ """
+ An interface to a single multi-step *inference* session for a specific remote module on a specific server
+
+ Each step() puts one request into an asyncio queue that feeds the rpc_inference stream and then waits for
+ one response from that stream. close() ends the session, sending an empty request if any step was made.
+
+ :note: this inference session is *not* fault-tolerant out of the box
+ """
+
+ def __init__(
+ self,
+ uid: ModuleUID,
+ rpc_info: RPCInfo,
+ inputs_queue: asyncio.Queue,
+ outputs_aiter: AsyncIterator,
+ *,
+ max_length: int,
+ points: int = 0,
+ ):
+ """
+ :param uid: uid of the remote module; may contain several uids joined with CHAIN_DELIMITER
+ :param rpc_info: mapping whose "inference_schema" entry tells step() how to serialize its tensors
+ :param inputs_queue: queue of outgoing requests, filled by step() and close()
+ :param outputs_aiter: async iterator over the responses to those requests
+ :param max_length: stored in the serialized session metadata that accompanies the first step
+ :param points: stored in the serialized session metadata alongside max_length
+ """
+ self.uid, self.rpc_info = uid, rpc_info
+ # one block per uid in the chain: number of delimiters plus one
+ self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
+ # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+ # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
+ self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+ self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+ self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
+ # stepped: set once the first request was queued; closed: set by close()
+ self.stepped = False
+ self.closed = False
+
+ @classmethod
+ async def _create(
+ cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
+ ) -> RemoteTransformerBlockInferenceSession:
+ """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+ inputs_queue = asyncio.Queue()
+ # the request stream is a generator that drains inputs_queue; **metadata is passed on to __init__
+ outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+ return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+
+ @staticmethod
+ async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+ """
+ Yield requests taken from the queue, waiting at most `timeout` for each one (asyncio.wait_for).
+ Stops after yielding a request that has neither a uid nor tensors.
+ """
+ while True:
+ next_input_message = await asyncio.wait_for(queue.get(), timeout)
+ yield next_input_message
+ if not next_input_message.uid and not next_input_message.tensors:
+ break # this message means "done sending"
+
+ def step(
+ self,
+ new_hidden_states: torch.Tensor,
+ prompts: Optional[torch.Tensor] = None,
+ hypo_ids: Optional[torch.Tensor] = None,
+ ):
+ """
+ Inference step: send a chunk of input tensors and receive a chunk of outputs
+ :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+ if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+ :hypo_ids: optional int64 tensor with one entry per row of new_hidden_states
+ :returns: the first deserialized output tensor; it must have the same shape as new_hidden_states
+ :raises Exception: if the session is already closed
+ """
+ if self.closed:
+ raise Exception("Session is closed, cannot perform step")
+ # missing prompts / hypo_ids are replaced with the DUMMY placeholder so that three tensors are always sent
+ if prompts is None or is_dummy(prompts):
+ prompts = DUMMY
+ else:
+ assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+ assert prompts.shape[0] == self.num_blocks
+ assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+ assert prompts.shape[2] <= new_hidden_states.shape[1]
+ assert prompts.shape[3] == new_hidden_states.shape[2]
+
+ if hypo_ids is None or is_dummy(hypo_ids):
+ hypo_ids = DUMMY
+ else:
+ assert len(hypo_ids) == len(new_hidden_states)
+ assert hypo_ids.dtype == torch.int64
+
+ # serialize inputs and put them into the queue
+ inputs = (new_hidden_states, prompts, hypo_ids)
+ outputs_serialized = RemoteExpertWorker.run_coroutine(
+ self._step(
+ runtime_pb2.ExpertRequest(
+ uid=self.uid,
+ tensors=[
+ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
+ for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
+ ],
+ # session metadata is attached to the first request only
+ metadata=self._serialized_metadata if not self.stepped else None,
+ )
+ )
+ )
+ outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+ assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+ return outputs[0]
+
+ async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+ """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+ await self._inputs_queue.put(inputs_serialized)
+ self.stepped = True
+ return await anext(self._outputs_stream)
+
+ def close(self):
+ """Finish a given inference session, close the underlying connection"""
+ if self._outputs_stream is None:
+ return # already closed
+ RemoteExpertWorker.run_coroutine(self._aclose_stream())
+ # dropping both references makes later close() calls return early
+ self._outputs_stream = self._inputs_queue = None
+ self.closed = True
+
+ async def _aclose_stream(self):
+ """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+ if self._outputs_stream is None:
+ return # already closed
+ # nothing is sent if no step was ever made
+ if self.stepped:
+ await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
+ try:
+ await anext(self._outputs_stream)
+ except StopAsyncIteration:
+ pass
+
+ def __del__(self):
+ """Close the session when the object is garbage-collected"""
+ self.close()
+
+ def __enter__(self):
+ """Return the session itself; entering an already closed session fails the assertion"""
+ assert not self.closed
+ return self
+
+ def __exit__(self, *exc_details):
+ """Close the session on leaving the `with` block; exception details are ignored"""
+ self.close()
+
+
+class RemoteSequentialInferenceSession:
+    """
+    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
+
+    On __enter__, the session asks the sequence manager for a chain of spans and opens one
+    RemoteTransformerBlockInferenceSession per span. Each step() passes the hidden states through
+    these per-span sessions in order.
+    """
+
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
+        """
+        :param sequence_manager: provides make_sequence(), block_uids and rpc_info for the remote blocks
+        :param p2p: P2P instance used to obtain a connection stub for each chosen span
+        :param timeout: passed to every per-span session when it is created
+        :param metadata: extra keyword arguments passed to every per-span session (e.g. max_length)
+        """
+        self.sequence_manager = sequence_manager
+        self.p2p = p2p
+        self.closed = False
+        self.chosen_spans: List[RemoteSpanInfo] = []
+        self.stack = contextlib.ExitStack()
+        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.metadata = metadata
+        self.timeout = timeout
+
+    def __enter__(self):
+        """Choose a chain of spans and open one inference session per span; returns self"""
+        assert not self.closed and not self.chosen_spans
+        self.stack.__enter__()
+        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        self.chosen_spans.extend(self.sequence_manager.make_sequence())
+
+        for chosen_span in self.chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
+            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
+            inference_session = RemoteExpertWorker.run_coroutine(
+                RemoteTransformerBlockInferenceSession._create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
+                )
+            )
+            self.inference_sessions.append(inference_session)
+            # registering the session on the stack guarantees it is closed together with this object
+            self.stack.enter_context(inference_session)
+
+        return self
+
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
+        """
+        Run one inference step through every chosen span, in order.
+
+        :param inputs: hidden states for the first span; every span must return a tensor of the same shape
+        :param prompts: optional deep prompts with one leading entry per block of the whole sequence
+          (ndim == 4, shape[0] == len(sequence_manager)); each span receives only the slice for its own blocks
+        :param kwargs: extra keyword arguments passed to every per-span step (e.g. hypo_ids)
+        :returns: the outputs of the last span
+        """
+        assert not self.closed
+        if torch.is_grad_enabled():
+            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
+        # chosen_spans and inference_sessions are filled in lockstep by __enter__, so zip pairs each session
+        # with the span it serves; slicing by that span gives every server the prompts for its own blocks
+        for span, session in zip(self.chosen_spans, self.inference_sessions):
+            outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        if not self.closed:
+            # ExitStack.__exit__ needs an (exc_type, exc_value, traceback) triple even without an exception
+            self.stack.__exit__(*exc_details or (None, None, None))
+            self.inference_sessions.clear()
+            self.closed = True
+
+    def __exit__(self, *exc_details):
+        """Close all per-span sessions, forwarding the exception details to the exit stack"""
+        self.close(*exc_details)
+
+    def __del__(self):
+        """Close the session when the object is garbage-collected"""
+        self.close()
diff --git a/petals/src/client/remote_forward_backward.py b/petals/src/client/remote_forward_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8713ffa9b32064b5455d78821189d63bcbe1e6f
--- /dev/null
+++ b/petals/src/client/remote_forward_backward.py
@@ -0,0 +1,156 @@
+"""
+Utility functions that call RPC forward or backward on a single remote server
+"""
+import asyncio
+from typing import Iterable, List, Sequence, Tuple
+
+import torch
+from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
+from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
+from hivemind.p2p import StubBase
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.streaming import split_for_streaming
+
+from src.data_structures import ModuleUID, RPCInfo
+
+
+async def run_remote_forward(
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and calls "rpc_forward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+
+    :param uid: uid of the remote module, placed into every request
+    :param stub: connection stub providing rpc_forward and rpc_forward_stream
+    :param rpc_info: mapping with "keyword_names", "forward_schema" and "outputs_schema" entries
+    :param inputs: exactly two positional tensors (currently enforced by an assert)
+    :param metadata: accepted for interface compatibility; not used in this function body
+    :param kwargs: keyword arguments; their number must match rpc_info["keyword_names"]
+    :returns: deserialized outputs packed according to rpc_info["outputs_schema"]
+    :raises TypeError: if the inputs do not match the (extended) forward schema
+    """
+
+    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
+    keyword_names = rpc_info["keyword_names"]
+    assert len(kwargs) == len(keyword_names), f"Keyword args should be {rpc_info['keyword_names']}"
+    # Note: keyword arguments are re-ordered like on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
+    kwargs = {key: kwargs[key] for key in keyword_names}
+
+    # The forward schema describes a single positional input; repeat it once per input to cover prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    structured_inputs = (inputs, kwargs)
+    if not nested_compare(structured_inputs, schema_with_prompts):
+        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
+    # detach to avoid pickling the computation graph
+    flat_tensors = tuple(tensor.cpu().detach() for tensor in nested_flatten(structured_inputs))
+    flat_protos = nested_flatten(schema_with_prompts)
+
+    # Asynchronous serialization: every tensor is serialized in the default executor
+    loop = asyncio.get_running_loop()
+    serialization_jobs = [
+        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+        for tensor, proto in zip(flat_tensors, flat_protos)
+    ]
+    serialized_tensors = await asyncio.gather(*serialization_jobs)
+
+    # call RPC on remote server; payloads above the unary limit go through the streaming RPC
+    payload_size = sum(t.element_size() * t.nelement() for t in flat_tensors)
+    rpc_call = _forward_stream if payload_size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
+    deserialized_outputs = await rpc_call(uid, serialized_tensors, stub, **kwargs)
+
+    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    """
+    Call rpc_forward_stream: split every serialized tensor into parts of at most DEFAULT_MAX_MSG_SIZE,
+    send one request per part, then deserialize the streamed response tensors.
+    """
+
+    def _iter_parts():
+        for serialized in serialized_tensors:
+            yield from split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE)
+
+    def _wrap_in_request(part):
+        return runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+
+    request_stream = amap_in_executor(_wrap_in_request, iter_as_aiter(_iter_parts()))
+    response_stream = await stub.rpc_forward_stream(request_stream)
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, response_stream)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    """Call rpc_forward with all serialized tensors in a single request and deserialize the response tensors."""
+    request = runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    response: runtime_pb2.ExpertResponse = await stub.rpc_forward(request)
+    return list(map(deserialize_torch_tensor, response.tensors))
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    """
+    Call rpc_backward_stream: split every serialized tensor into parts of at most DEFAULT_MAX_MSG_SIZE,
+    send one request per part, then deserialize the streamed gradient tensors.
+    """
+
+    def _iter_parts():
+        for serialized in serialized_tensors:
+            yield from split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE)
+
+    def _wrap_in_request(part):
+        return runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+
+    request_stream = amap_in_executor(_wrap_in_request, iter_as_aiter(_iter_parts()))
+    grad_stream = await stub.rpc_backward_stream(request_stream)
+
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_stream)
+    return await deserialize_tensor_stream(tensors_stream)
+
+
+async def run_remote_backward(
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    inputs: torch.Tensor,
+    grad_outputs: List[torch.Tensor],
+    *extra_tensors: torch.Tensor,
+    **kwargs,
+) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad outputs and calls "rpc_backward" on a remote server.
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
+
+    :param uid: uid of the remote module, placed into every request
+    :param stub: connection stub providing rpc_backward and rpc_backward_stream
+    :param rpc_info: mapping with "forward_schema" and "outputs_schema" entries
+    :param inputs: a single input tensor (enforced by an assert)
+    :param grad_outputs: gradients w.r.t. the module outputs; moved to CPU before serialization
+    :param extra_tensors: additional tensors serialized after the gradients (schema: first forward arg)
+    :param kwargs: forwarded into every request
+    :returns: deserialized gradient tensors, exactly as returned by the server
+    """
+
+    cpu_grad_outputs = tuple(grad.cpu() for grad in grad_outputs)
+    flat_tensors = tuple(nested_flatten((inputs, cpu_grad_outputs, *extra_tensors)))
+
+    # Modify forward_schema to support prompts
+    forward_schema = rpc_info["forward_schema"]
+    args_schema, kwargs_schema = forward_schema
+    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    # TODO generalize this
+    prompts_schema = next(iter(args_schema))
+    flat_protos = tuple(nested_flatten((forward_schema, rpc_info["outputs_schema"], prompts_schema)))
+
+    # Asynchronous serialization: every tensor is serialized in the default executor
+    loop = asyncio.get_running_loop()
+    serialization_jobs = [
+        loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
+        for tensor, proto in zip(flat_tensors, flat_protos)
+    ]
+    serialized_tensors = await asyncio.gather(*serialization_jobs)
+
+    # payloads above the unary limit go through the streaming RPC
+    payload_size = sum(t.element_size() * t.nelement() for t in flat_tensors)
+    rpc_call = _backward_stream if payload_size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
+    return await rpc_call(uid, serialized_tensors, stub, **kwargs)
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
+) -> List[torch.Tensor]:
+    """Call rpc_backward with all serialized tensors in a single request and deserialize the returned gradients."""
+    request = runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
+    response: runtime_pb2.ExpertResponse = await stub.rpc_backward(request)
+    return list(map(deserialize_torch_tensor, response.tensors))
diff --git a/petals/src/client/remote_generation.py b/petals/src/client/remote_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2be2c949138ab172a64f0f4c7a5770f7ebb0e6c
--- /dev/null
+++ b/petals/src/client/remote_generation.py
@@ -0,0 +1,257 @@
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
+
+
+class RemoteGenerationMixin:
+    """
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
+    The class can be used for:
+        - *greedy decoding*.
+        - *multinomial sampling* (top-k or nucleus).
+
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
+    """
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        do_sample: Optional[bool] = None,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        decoding_algorithm: Optional[DecodingAlgorithm] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head.
+
+        :param inputs: The input tokens to the model, a 2d tensor [batch, length]; if None, a single bos token is used.
+        :param do_sample: Whether to sample from the model predictions or take the argmax.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of results to return.
+        :param top_p: The cumulative probability of results to return.
+        :param bos_token_id: The id of the beginning of sentence token (defaults to the config value).
+        :param eos_token_id: The id of the end of sentence token (defaults to the config value).
+        :param pad_token_id: The id of the padding token (defaults to the config value).
+        :param max_length: The maximum total length, counted together with the prefix (inputs and tuned prompts).
+        :param max_new_tokens: The maximum number of tokens to generate. Set exactly one of max_length / max_new_tokens.
+        :param decoding_algorithm: The decoding algorithm to use; overrides do_sample / top_k / top_p.
+        :param provided_constraints: A list of constraints to use. The list is only read, never modified.
+        :param model_kwargs: Additional arguments; logits_processor, logits_wrapper and stopping_criteria are rejected.
+        :returns: the (pad-stripped) inputs concatenated with all generated tokens
+        :raises ValueError: if do_sample is set without a usable top_k or top_p, or with both
+        """
+
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+        # the prefix consists of the input tokens plus the tuned prompt tokens (pre_seq_len)
+        prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
+
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+
+        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
+        if max_length is not None and max_new_tokens is None:
+            max_new_tokens = max_length - prefix_length
+            # report prefix_length: it is the value actually compared, and it is defined even when inputs is None
+            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {prefix_length}"
+        elif max_length is None and max_new_tokens is not None:
+            max_length = prefix_length + max_new_tokens
+
+        if inputs is None:
+            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+            inputs = torch.tensor([[bos_token_id]])
+
+        if decoding_algorithm is None:
+            if do_sample:
+                # raises ValueError before any remote session is opened if top_k / top_p are unusable
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+            else:
+                decoding_algorithm = GreedyAlgorithm()
+
+        constraints = self._get_constraints(
+            inputs=inputs,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            max_new_tokens=max_new_tokens,
+            provided_constraints=provided_constraints,
+        )
+
+        with self.transformer.h.inference_session(max_length=max_length) as sess:
+            outputs = []
+            # without a pad token id there is nothing to strip; comparing a tensor with None is not a tensor op
+            has_padding = pad_token_id is not None and torch.any(inputs == pad_token_id)
+            if has_padding:  # TODO: move to prepare_inputs
+                # keep only the leading columns: drop as many trailing columns as the most-padded row has pads
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            else:
+                outputs += [inputs]
+            last_token_id = None
+            seq_idx = outputs[0].size(1)
+            hypo_ids = torch.arange(outputs[0].size(0))
+            while True:
+                # after the first iteration only the newest token is embedded and sent to the session
+                embs = self.transformer.word_embeddings(outputs[-1])
+                intermediate_prompts = None
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    # tuned prompts are prepended once, together with the initial prefix
+                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
+                embs = self.transformer.word_embeddings_layernorm(embs)
+                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                hidden_state = self.transformer.ln_f(hidden_state)
+                lm_logits = self.lm_head(hidden_state)
+
+                for constraint in constraints:
+                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
+                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+                    # still inside the original (padded) inputs: keep real input tokens, generate only over pads
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id
+
+                # the loop ends only when every sequence has just produced eos
+                if torch.all(last_token_id == eos_token_id):
+                    break
+
+                outputs.append(last_token_id)
+                seq_idx += 1
+
+        return torch.cat(outputs, dim=-1)
+
+    def greedy_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
+
+        :param input_ids: The input tokens to the model.
+        :param max_length: The maximum length of the sequence to generate; forwarded to generate() as max_new_tokens.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs forwarded to generate().
+        """
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=GreedyAlgorithm(),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def sample(
+        self,
+        input_ids: torch.LongTensor,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses sampling.
+        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
+
+        :param input_ids: The input tokens to the model.
+        :param temperature: The temperature to use for sampling.
+        :param top_k: The number of samples to use for top_k sampling.
+        :param top_p: The probability of using top_p sampling.
+        :param max_length: The maximum length of the sequence to generate; forwarded to generate() as max_new_tokens.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs to pass to the model.
+        :raises ValueError: if neither or both of top_k / top_p are provided
+        """
+
+        return self.generate(
+            inputs=input_ids,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
+
+    def beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """Not implemented; always raises NotImplementedError."""
+        raise NotImplementedError
+
+    def beam_sample(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """Not implemented; always raises NotImplementedError."""
+        raise NotImplementedError
+
+    def group_beam_search(
+        self,
+        input_ids: torch.LongTensor,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        """Not implemented; always raises NotImplementedError."""
+        raise NotImplementedError
+
+    def _choose_sample_algorithm(
+        self,
+        temperature: float = 1.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ) -> DecodingAlgorithm:
+        """
+        Pick the sampling algorithm: top-k if top_k is set, nucleus if top_p is set.
+
+        :raises ValueError: if both top_k and top_p are given, or if neither is usable. Previously the latter
+          case returned None, which made generate() fail later with "'NoneType' object is not callable".
+        """
+        if (top_k is not None) and (top_p is not None):
+            raise ValueError("You have to provide only top_k or top_p for sampling")
+        if top_k:
+            return TopKAlgorithm(top_k, temperature)
+        elif top_p:
+            return NucleusAlgorithm(top_p, temperature)
+        raise ValueError("You have to provide a non-zero top_k or top_p for sampling")
+
+    def _get_constraints(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        eos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        provided_constraints: List[ABCBloomConstraint] = [],
+    ) -> List[ABCBloomConstraint]:
+        """
+        Build the constraint list: user-provided constraints first, then MaxNewTokensConstraint
+        (only if max_new_tokens is set), then EosConstraint. A new list is returned on every call.
+        """
+        constraints = []
+        constraints.extend(provided_constraints)
+        if max_new_tokens is not None:
+            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
+        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
+        return constraints
diff --git a/petals/src/client/remote_model.py b/petals/src/client/remote_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..158b7a182d6ee627372d4067456c546ebcfba4b6
--- /dev/null
+++ b/petals/src/client/remote_model.py
@@ -0,0 +1,197 @@
+# this code is in active development, interfaces may change
+from typing import Optional, Tuple
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind import get_logger, use_hivemind_log_handler
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
+from src.bloom.model import (
+ BloomConfig,
+ BloomForCausalLM,
+ BloomForSequenceClassification,
+ BloomModel,
+ BloomPreTrainedModel,
+ LMHead,
+)
+from src.client.remote_generation import RemoteGenerationMixin
+from src.client.remote_sequential import RemoteSequential
+from src.utils.misc import DUMMY
+
+# Module-level logging setup: select hivemind's "in_root_logger" handler mode, then create this file's logger.
+# NOTE(review): presumably this routes hivemind log records through the root logger - confirm in hivemind docs.
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class DistributedBloomConfig(BloomConfig):
+ """
+ A bloom config that contains information about DHT peers.
+ To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
+
+ DistributedBloomModel asserts at construction time that dht_prefix is set and that
+ initial_peers or dht is provided.
+ """
+
+ initial_peers: Tuple[str, ...] = () # a list of initial peers for hivemind DHT
+ # NOTE: dht_prefix is only annotated here; this class assigns no default value to it
+ dht_prefix: str # a prefix for all dht keys that correspond to this model (usually equal to model name)
+ dht: Optional[hivemind.DHT] = None # a running DHT instance, e.g. when using the same DHT for multiple models
+ chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
+ pre_seq_len: int = 0 # a number of tokens for prompt tuning.
+ # DistributedBloomModel currently accepts only None and the two "*ptune" modes; other values raise NotImplementedError
+ tuning_mode: Optional[str] = None # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+
+
+class DistributedBloomModel(BloomModel):
+    """
+    BloomModel, but all transformer layers are hosted by the swarm
+
+    Embeddings and layer norms are kept locally (with gradients disabled); the transformer blocks
+    are replaced by a RemoteSequential. Optional prompt tuning adds trainable prompt embeddings.
+    """
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        """
+        :param config: must define dht_prefix and either initial_peers or a running dht
+        :raises NotImplementedError: if config.tuning_mode is set to anything but a "*ptune" mode
+        """
+        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
+        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+
+        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.n_layer = n_layer
+
+        # reuse the DHT from the config if there is one, otherwise start a client-mode DHT from initial_peers
+        dht = (
+            config.dht
+            if config.dht is not None
+            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+        )
+        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
+        self.h = RemoteSequential(config, dht, config.dht_prefix)
+
+        # Forbid accumulate grads for embeddings and layernorm
+        # (prompt embeddings are created afterwards, so they are not affected by this call)
+        self.set_requires_grad(False)
+
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            if config.tuning_mode == "deep_ptune":
+                self.intermediate_prompt_embeddings = nn.Embedding(
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
+                )
+                self.intermediate_prompt_embeddings.weight.data.zero_()
+        elif config.tuning_mode:
+            # use config.tuning_mode: the model itself has no tuning_mode attribute, so formatting
+            # self.tuning_mode would raise AttributeError instead of the intended NotImplementedError
+            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")
+
+    def set_requires_grad(self, value):
+        """Set requires_grad to `value` for every parameter currently registered on this model"""
+        for p in self.parameters():
+            p.requires_grad = value
+
+    def get_prompt(self, batch_size):
+        """
+        Build tuned prompts for a batch. Only usable when a "*ptune" tuning mode was configured.
+
+        :returns: (prompts, intermediate_prompts); prompts come from prompt_embeddings expanded to
+          batch_size rows; intermediate_prompts are permuted to [len(self.h), batch_size, pre_seq_len,
+          hidden_size] in deep_ptune mode and are DUMMY otherwise
+        """
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
+            )
+            # move the per-layer axis first: [layers, batch, prefix, hidden]
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+        return prompts, intermediate_prompts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        """
+        Embed the inputs locally, run the remote transformer blocks and apply the final layer norm.
+
+        :param input_ids: token ids; mutually exclusive with inputs_embeds
+        :param inputs_embeds: precomputed embeddings; mutually exclusive with input_ids
+        :param attention_mask: must be None, attention masks are not supported
+        :param kwargs: ignored; values other than None / False are reported at debug level
+        :returns: BaseModelOutputWithPastAndCrossAttentions with only last_hidden_state filled in
+        :raises ValueError: if both or neither of input_ids / inputs_embeds are given
+        """
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # output shape is based on the original inputs: the tuned prefix is removed again below
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
+    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        """Build the distributed transformer and an LM head that shares its word embeddings"""
+        # call the pretrained-model initializer directly instead of BloomForCausalLM.__init__
+        BloomPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHead(config, self.transformer.word_embeddings)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """Return the word embedding module of the transformer"""
+        return self.transformer.word_embeddings
+
+    def get_output_embeddings(self):
+        """Return the LM head, or None when the config ties input and output embeddings"""
+        if self.config.tie_word_embeddings:
+            return None
+        return self.lm_head
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
+        """Replace the word embeddings in both the transformer and the LM head (they share one module)"""
+        assert isinstance(new_embeddings, nn.Embedding)
+        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
+        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
+
+    def set_output_embeddings(self, new_lm_head: nn.Linear):
+        """
+        Copy the weights (and the bias, if this head has one) of `new_lm_head` into the LM head in-place.
+
+        The LM head bias may be None (set_input_embeddings accounts for that as well); in that case
+        only the weights are copied instead of failing on an item assignment to None.
+        """
+        with torch.no_grad():
+            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
+            if self.lm_head.bias is not None:
+                self.lm_head.bias[...] = new_lm_head.bias
+
+
+class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+    """Sequence classification model whose transformer is a DistributedBloomModel, with a bias-free linear head"""
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        """Build the distributed transformer and the classification head"""
+        # call the pretrained-model initializer directly instead of BloomForSequenceClassification.__init__
+        BloomPreTrainedModel.__init__(self, config)
+
+        label_count = config.num_labels
+        self.num_labels = label_count
+
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(in_features=config.hidden_size, out_features=label_count, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
diff --git a/petals/src/client/remote_sequential.py b/petals/src/client/remote_sequential.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e63b2da742662066d457cd8f1bd94669ff33ef
--- /dev/null
+++ b/petals/src/client/remote_sequential.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from torch import nn
+
+import src
+from src.client.inference_session import RemoteSequentialInferenceSession
+from src.client.sequence_manager import RemoteSequenceManager
+from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
+from src.data_structures import UID_DELIMITER
+from src.utils.misc import DUMMY
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequential(nn.Module):
+    """
+    A sequence of transformer blocks hosted by the swarm.
+
+    :param config: model config; provides ``n_layer`` and the default ``dht_prefix``
+    :param dht: DHT instance used to find the servers that host each block
+    :param dht_prefix: optional override for ``config.dht_prefix``
+    :param p2p: optional P2P instance; if None, one is obtained from ``dht.replicate_p2p()``
+    :param sequence_manager: optional existing manager to reuse instead of creating a new one
+    """
+
+    def __init__(
+        self,
+        config: src.DistributedBloomConfig,
+        dht: DHT,
+        dht_prefix: Optional[str] = None,
+        p2p: Optional[P2P] = None,
+        sequence_manager: Optional[RemoteSequenceManager] = None,
+    ):
+        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
+        super().__init__()
+        self.config = config
+        self.dht = dht
+        self.dht_prefix = dht_prefix or config.dht_prefix
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
+
+        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
+        # Use the effective prefix (self.dht_prefix) so that an explicit dht_prefix override
+        # is honoured; with the default dht_prefix=None this equals config.dht_prefix.
+        block_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
+        if sequence_manager is None:
+            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
+            self.is_subsequence = False
+        else:
+            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            self.sequence_manager = sequence_manager
+            assert isinstance(sequence_manager.block_uids, list)
+            # A reused manager is a sub-sequence unless its uids are exactly blocks 0..num_blocks-1
+            self.is_subsequence = self.sequence_manager.block_uids != block_uids
+
+    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        """Run inputs (and optional prompts) through the remote blocks via the autograd function"""
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        return outputs
+
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
+        """Return a RemoteTransformerBlock for an int index or a RemoteSequential for a slice"""
+        assert isinstance(ix, (int, slice))
+        if isinstance(ix, int):
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
+        else:
+            return RemoteSequential(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
+
+    def __iter__(self):
+        """Yield every block of the sequence as an individual single-block module"""
+        for block_index in range(len(self)):
+            yield self[block_index]
+
+    def __len__(self):
+        """Number of blocks handled by the underlying sequence manager"""
+        return len(self.sequence_manager)
+
+    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
+        """Refresh the sequence manager and open a new inference session over this sequence"""
+        self.sequence_manager.update_()
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
+
+    def extra_repr(self) -> str:
+        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """A single swarm-hosted transformer block, i.e. a ``RemoteSequential`` of length one
+
+    Deprecated: kept only for backward compatibility.
+    It will be removed soon; use ``RemoteSequential`` directly instead.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        num_blocks = len(self)
+        assert num_blocks == 1, "Remote Block is a sequence size 1"
+
+    def extra_repr(self):
+        # there is exactly one block, so its uid alone identifies this module
+        only_uid = self.sequence_manager.block_uids[0]
+        return f"{only_uid}"
diff --git a/petals/src/client/sequence_manager.py b/petals/src/client/sequence_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c151634610553b144e060c098abd59b8e3c53cb
--- /dev/null
+++ b/petals/src/client/sequence_manager.py
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+import random
+import threading
+from typing import List, Optional, Sequence, Tuple, Union
+
+from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from src.client.spending_policy import NoSpendingPolicy
+from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from src.dht_utils import get_remote_module_infos
+from src.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequenceManager:
+    """
+    Keeps and updates the meta-information about which peers host which blocks.
+    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+
+    :param dht: DHT instance that is queried for block metadata
+    :param block_uids: uids of the blocks managed by this instance, in order (must be non-empty)
+    :param p2p: P2P instance used to open stubs to servers
+    :param max_retries: maximum number of attempts made by ``rpc_info`` before the error is re-raised
+    """
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+        assert len(block_uids) > 0, "Sequences must contain at least one block"
+        self.dht, self.p2p = dht, p2p
+        self.block_uids: List[ModuleUID] = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.last_update_time: DHTExpiration = -float("inf")
+        self.max_retries = max_retries
+        self._rpc_info = None
+        self.lock_changes = threading.Lock()
+        # Populate block_infos and spans right away; the checks below rely on it
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        :returns: spans that cover [start_index, end_index) without gaps; the last span never exceeds end_index
+        """
+        end_index = end_index if end_index is not None else len(self.block_uids)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            # A server may host blocks beyond the requested range: clip the span so that
+            # the returned route ends exactly at end_index
+            span_end = min(chosen_span.end, end_index)
+            span_sequence.append(RemoteSpanInfo(start=current_index, end=span_end, peer_id=chosen_span.peer_id))
+            current_index = span_end
+
+        return span_sequence
+
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
+        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
+        assert isinstance(ix, (int, slice))
+        if not isinstance(ix, slice):
+            ix = slice(int(ix), int(ix) + 1, 1)
+        with self.lock_changes:
+            # NOTE: the constructor runs its own update_(); its results are then replaced
+            # by the infos already known to this manager
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
+            subseq.block_infos = self.block_infos[ix]
+            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq.last_update_time = self.last_update_time
+        return subseq
+
+    def update_(self):
+        """Re-fetch block infos from the DHT and recompute spans, under ``lock_changes``"""
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        """
+        Fetch the latest infos for all blocks; entries that are missing or of an unexpected type
+        are logged and skipped, keeping the previously known info for that block
+        """
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+                continue
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+                # do not dereference or store an object of the wrong type
+                continue
+            if not info.servers:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        """
+        Group blocks into maximal runs of consecutive blocks served by the same ONLINE peer
+
+        :returns: (spans sorted from longest to shortest, for each block the list of spans that contain it)
+        """
+        closed_spans = []
+        active_spans = {}  # peer_id -> span that is still being extended
+        for block_index, info in enumerate(block_infos):
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
+
+            # Close spans of peers that do not serve this block (or are not ONLINE for it),
+            # and close everything once the last block is reached
+            for peer_id in list(active_spans.keys()):
+                if (
+                    info is None
+                    or peer_id not in info.servers
+                    or info.servers[peer_id].state != ServerState.ONLINE
+                    or block_index == len(block_infos) - 1
+                ):
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans, f"spans: {active_spans}"
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
+
+    def __len__(self):
+        return len(self.block_uids)
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            for attempt_no in range(self.max_retries):
+                try:
+                    self.update_()
+                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    break
+                except Exception as e:
+                    # the last failed attempt propagates the error, earlier ones are only logged
+                    if attempt_no + 1 >= self.max_retries:
+                        raise
+                    logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
+        return self._rpc_info
diff --git a/petals/src/client/sequential_autograd.py b/petals/src/client/sequential_autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..408e6222111446f077fc6bfcc142ead1229ffd54
--- /dev/null
+++ b/petals/src/client/sequential_autograd.py
@@ -0,0 +1,204 @@
+"""
+A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
+"""
+import asyncio
+import itertools
+import logging
+from typing import List, Optional, Sequence, Tuple
+
+import torch
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+
+from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
+from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
+
+MAX_TOKENS_IN_BATCH = 1024
+
+
+async def sequential_forward(
+    inputs: torch.Tensor,
+    prompts: torch.Tensor,
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None,
+    min_backoff: float = 1.0,
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    Constructs a routing path from block ``start_index`` to block ``end_index`` (non-inclusive).
+    Performs chained forward for each subsequence of blocks on the path.
+    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+
+    :param inputs: 3-dimensional tensor fed into the first span
+    :param prompts: per-block prompts (one entry per block uid), or DUMMY if there are none
+    :param sequence_manager: provides block uids, routes, the p2p instance and rpc_info
+    :param start_index: index of the first block to run, default = first block
+    :param end_index: index after the last block to run, default = after the last block
+    :param min_backoff: base delay in seconds; it is doubled after every failed attempt on a span
+    :returns: (final outputs, inputs of every completed span, the completed spans)
+    """
+
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
+
+    sequences = sequence_manager.make_sequence(start_index, end_index)
+    intermediate_inputs = []
+    done_sequences = []
+    outputs = inputs
+
+    while len(sequences) > 0:
+        for attempt_no in itertools.count():
+            span = sequences.pop(0)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            try:
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+
+                (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+
+                assert isinstance(outputs, torch.Tensor)
+                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
+
+                # Save intermediate inputs and subsequences if the forward is already done for them
+                intermediate_inputs.append(inputs)
+                done_sequences.append(span)
+
+                inputs = outputs
+                break
+            except Exception as e:
+                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                await asyncio.sleep(min_backoff * 2**attempt_no)
+
+                # Re-route only the part that is still missing; end_index must be passed along,
+                # otherwise the new route would run all the way to the last block
+                backup_sequences = sequence_manager.make_sequence(span.start, end_index)
+                assert backup_sequences[0].start == span.start
+                sequences = backup_sequences
+
+    return outputs, intermediate_inputs, done_sequences
+
+
+async def sequential_backward(
+    grad_outputs: Sequence[torch.Tensor],
+    intermediate_inputs: List[torch.Tensor],
+    prompts: torch.Tensor,
+    forward_sequences: List[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
+    min_backoff: float = 1.0,
+) -> Tuple[Sequence[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Performs chained backward for each forward subsequence.
+    If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
+
+    :param grad_outputs: gradients w.r.t. the outputs of the last span
+    :param intermediate_inputs: inputs of every span, as recorded by sequential_forward (not modified)
+    :param prompts: per-block prompts, or DUMMY if there are none
+    :param forward_sequences: spans used during forward, aligned with intermediate_inputs (not modified)
+    :param sequence_manager: provides block uids, the p2p instance and rpc_info
+    :param min_backoff: base delay in seconds; it is doubled after every failed attempt on a span
+    :returns: (gradients w.r.t. the first span's inputs, gradients w.r.t. prompts or None)
+    """
+    assert len(intermediate_inputs) == len(forward_sequences)
+
+    # Work on copies: the caller's lists are kept on the autograd context, and draining
+    # them in place would turn any later backward call into a silent no-op
+    intermediate_inputs = list(intermediate_inputs)
+    forward_sequences = list(forward_sequences)
+
+    grad_prompts_reversed = []
+    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        for attempt_no in itertools.count():
+            inputs = intermediate_inputs.pop(-1)
+            span = forward_sequences.pop(-1)
+            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            try:
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                grad_outputs, *span_grad_prompts = await run_remote_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
+                )
+                grad_outputs = [grad_outputs]
+                grad_prompts_reversed.extend(span_grad_prompts)
+                break
+            except Exception as e:
+                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                await asyncio.sleep(min_backoff * 2**attempt_no)
+
+                # Redo the forward pass for the failed span over a fresh route, then put the
+                # resulting spans and inputs back so that the next attempt picks them up
+                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
+                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                )
+                assert len(intermediate_inputs) == len(forward_sequences)
+                assert backup_forward_sequences[0].start == span.start
+                assert backup_forward_sequences[-1].end == span.end
+
+                forward_sequences.extend(backup_forward_sequences)
+                intermediate_inputs.extend(backup_intermediate_inputs)
+
+    # For now, we do not support mixed dummy and grad prompts
+    # Concat in num_layer dimension
+    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+    return grad_outputs, grad_prompts
+
+
+async def _gather_forward(input_batches, prompt_batches, sequence_manager):
+    """Run sequential_forward for every (input batch, prompt batch) pair concurrently with asyncio.gather"""
+    pending = []
+    for batch_inputs, batch_prompts in zip(input_batches, prompt_batches):
+        pending.append(sequential_forward(batch_inputs, batch_prompts, sequence_manager))
+    return await asyncio.gather(*pending)
+
+
+async def _gather_backward(
+    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+):
+    """Run sequential_backward for every batch concurrently with asyncio.gather"""
+    per_batch_args = zip(grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences)
+    pending = []
+    for grad_output, batch_inputs, batch_prompts, batch_spans in per_batch_args:
+        # sequential_backward expects a sequence of gradients, hence the one-element tuple
+        pending.append(sequential_backward((grad_output,), batch_inputs, batch_prompts, batch_spans, sequence_manager))
+    return await asyncio.gather(*pending)
+
+
+class _RemoteSequentialAutogradFunction(torch.autograd.Function):
+    """
+    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    The input is split along dim 0 into batches whose size is derived from MAX_TOKENS_IN_BATCH,
+    and the batches are processed in parallel.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+        # at least one sample per batch, even if a single sample exceeds the token budget
+        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
+        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        if not is_dummy(prompts):
+            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
+        else:
+            prompt_batches = [DUMMY] * len(input_batches)
+
+        sequence_manager.rpc_info  # lazy init
+        gather_coroutine = _gather_forward(input_batches, prompt_batches, sequence_manager)
+        outputs = RemoteExpertWorker.run_coroutine(gather_coroutine)
+        assert len(outputs) == len(input_batches)
+
+        # every result is (outputs, intermediate inputs, spans) for one batch
+        output_batches, intemediate_input_batches, sequences_for_batches = [], [], []
+        for batch_result in outputs:
+            output_batches.append(batch_result[0])
+            intemediate_input_batches.append(batch_result[1])
+            sequences_for_batches.append(batch_result[2])
+
+        ctx.prompt_batches = prompt_batches
+        ctx.sequence_manager = sequence_manager
+        ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.sequences_for_batches = sequences_for_batches
+        return torch.cat(output_batches, dim=0)
+
+    @staticmethod
+    def backward(ctx, grad_outputs: torch.Tensor):
+        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
+        ctx.sequence_manager.rpc_info  # lazy init
+
+        # same batch size formula as in forward, so the batches line up
+        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
+        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+
+        gather_coroutine = _gather_backward(
+            grad_output_batches,
+            intermediate_input_batches,
+            ctx.prompt_batches,
+            forward_sequences,
+            ctx.sequence_manager,
+        )
+        outputs = RemoteExpertWorker.run_coroutine(gather_coroutine)
+
+        # every result is (sequence of input gradients, prompt gradients or None) for one batch
+        grad_input_batches, grad_prompt_batches = [], []
+        for batch_result in outputs:
+            grad_input_batches.append(batch_result[0][0])
+            grad_prompt_batches.append(batch_result[1])
+
+        grad_inputs = torch.cat(grad_input_batches, dim=0)
+        if any(grad_prompt is None for grad_prompt in grad_prompt_batches):
+            grad_prompts = None
+        else:
+            grad_prompts = torch.cat(grad_prompt_batches, dim=1)
+        # no gradient for the sequence_manager argument
+        return (grad_inputs, grad_prompts, None)
diff --git a/petals/src/client/spending_policy.py b/petals/src/client/spending_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..770d25a59e9660ce3dfc4bbf9d74afd63df9acf2
--- /dev/null
+++ b/petals/src/client/spending_policy.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from hivemind.proto.runtime_pb2 import ExpertRequest
+
+
+class SpendingPolicyBase(ABC):
+    """Abstract base for spending policies; subclasses must implement ``get_points``"""
+
+    @abstractmethod
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        """Return the number of points for ``request`` sent to ``method_name`` (no default implementation)"""
+
+
+class NoSpendingPolicy(SpendingPolicyBase):
+    """Spending policy that always reports zero points, whatever the request or method"""
+
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        no_points = 0.0
+        return no_points
diff --git a/petals/src/data_structures.py b/petals/src/data_structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..919c8c1c281454807beb103ed7d0a13d41b34b0b
--- /dev/null
+++ b/petals/src/data_structures.py
@@ -0,0 +1,41 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict
+
+from hivemind import PeerID
+
+ModuleUID = str
+UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
+CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
+
+
+class ServerState(Enum):
+ # State that a server declares for the blocks it announces.
+ # The integer values are what gets written to and parsed from DHT records
+ # (`state.value` on write, `ServerState(state)` on read), so they must stay stable.
+ OFFLINE = 0
+ # presumably a server that is still starting up — TODO confirm; clients route only through ONLINE servers
+ JOINING = 1
+ ONLINE = 2
+
+
+@dataclass
+class ServerInfo:
+ # What one server declares about itself for a particular block.
+ # state: the declared ServerState
+ state: ServerState
+ # throughput: performance figure declared by the server; readers accept only finite,
+ # non-negative floats (units are not defined here — TODO confirm)
+ throughput: float
+
+
+@dataclass
+class RemoteModuleInfo:
+ """A remote module that is served by one or more servers"""
+
+ # uid: identifier of the module (a ModuleUID, i.e. a plain string)
+ uid: ModuleUID
+ # servers: what each peer that announced this module declared about itself, keyed by PeerID
+ servers: Dict[PeerID, ServerInfo]
+
+
+@dataclass
+class RemoteSpanInfo:
+ """A chain of remote blocks served by one specific remote peer"""
+
+ # start: index of the first block of the span (inclusive)
+ start: int
+ # end: index after the last block of the span (exclusive), so the span covers range(start, end)
+ end: int
+ # peer_id: the peer that serves every block of the span
+ peer_id: PeerID
+
+
+RPCInfo = Dict[str, Any]
diff --git a/petals/src/dht_utils.py b/petals/src/dht_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..78ef08371bcd02d5d6cbad425a15b1b2e9c4fbb9
--- /dev/null
+++ b/petals/src/dht_utils.py
@@ -0,0 +1,180 @@
+"""
+Utilities for declaring and retrieving active model layers using a shared DHT.
+"""
+from __future__ import annotations
+
+import math
+from functools import partial
+from typing import Dict, List, Optional, Sequence, Union
+
+from hivemind.dht import DHT, DHTNode, DHTValue
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import PeerID
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
+
+import src
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+def declare_active_modules(
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: DHTExpiration,
+    state: ServerState,
+    throughput: float,
+    wait: bool = True,
+) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
+    """
+    Declare that your node serves the specified modules; update timestamps if declared previously
+
+    :param uids: a list of module ids to declare (a single string is treated as a one-element list)
+    :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param throughput: specify your performance in terms of compute throughput
+    :param expiration_time: declared modules will be visible for this many seconds
+    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
+    """
+    # normalize uids to a list: wrap a lone string, materialize any other sequence
+    uids = [uids] if isinstance(uids, str) else uids
+    uids = uids if isinstance(uids, list) else list(uids)
+    for uid in uids:
+        is_single_module_uid = isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+        assert is_single_module_uid
+
+    declare = partial(
+        _declare_active_modules,
+        uids=uids,
+        expiration_time=expiration_time,
+        state=state,
+        throughput=throughput,
+    )
+    return dht.run_coroutine(declare, return_future=not wait)
+
+
+async def _declare_active_modules(
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    expiration_time: DHTExpiration,
+    state: ServerState,
+    throughput: float,
+) -> Dict[ModuleUID, bool]:
+    """Store one (state value, throughput) record per uid under this peer's base58 id as the subkey"""
+    num_uids = len(uids)
+    num_workers = num_uids if dht.num_workers is None else min(num_uids, dht.num_workers)
+    own_subkey = dht.peer_id.to_base58()
+    record = (state.value, throughput)
+    return await node.store_many(
+        keys=uids,
+        subkeys=[own_subkey] * num_uids,
+        values=[record] * num_uids,
+        expiration_time=expiration_time,
+        num_workers=num_workers,
+    )
+
+
+def get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[src.RemoteSequential, MPFuture]:
+    """
+    Build a RemoteSequential for blocks ``start`` .. ``stop - 1``
+
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    """
+    coroutine = _get_remote_sequence(dht, start, stop, config, dht_prefix)
+    return RemoteExpertWorker.run_coroutine(coroutine, return_future=return_future)
+
+
+async def _get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> src.RemoteSequential:
+    """
+    Create a sequence manager and a RemoteSequential for blocks ``start`` .. ``stop - 1``
+
+    :param dht_prefix: optional override for ``config.dht_prefix`` when forming block uids
+    """
+    # Honour an explicit dht_prefix; with the default (None) this falls back to config.dht_prefix,
+    # the same rule that RemoteSequential applies to its own dht_prefix
+    prefix = dht_prefix or config.dht_prefix
+    uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
+    p2p = await dht.replicate_p2p()
+    manager = src.RemoteSequenceManager(dht, uids, p2p)
+    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+
+
+def get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
+    """
+    :param uid_or_uids: find one or more modules with these ids from across the DHT
+    :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
+    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+    :returns: a single RemoteTransformerBlock for a single uid, otherwise a list of [RemoteTransformerBlock]
+    """
+    coroutine = _get_remote_module(dht, uid_or_uids, config, dht_prefix)
+    return RemoteExpertWorker.run_coroutine(coroutine, return_future=return_future)
+
+
+async def _get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+    """Create one RemoteTransformerBlock per uid; a single uid yields a single block instead of a list"""
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    if single_uid:
+        uids = [uid_or_uids]
+    else:
+        uids = uid_or_uids
+    p2p = await dht.replicate_p2p()
+
+    modules = []
+    for uid in uids:
+        # each block gets its own single-uid sequence manager
+        manager = src.RemoteSequenceManager(dht, [uid], p2p)
+        block = src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=manager)
+        modules.append(block)
+
+    if single_uid:
+        return modules[0]
+    return modules
+
+
+def get_remote_module_infos(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+) -> List[Optional[RemoteModuleInfo]]:
+    """
+    Look up module infos in the DHT (blocking)
+
+    :param uid_or_uids: a single uid or a list of uids
+    :param expiration_time: passed to the DHT lookup; None means the current DHT time
+    :returns: one entry per uid (None if nothing usable was found); a single uid returns its entry alone
+    """
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    if single_uid:
+        uids = [uid_or_uids]
+    else:
+        uids = uid_or_uids
+
+    lookup = partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time)
+    infos = dht.run_coroutine(lookup, return_future=False)
+
+    if single_uid:
+        return infos[0]
+    return infos
+
+
+async def _get_remote_module_infos(
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+) -> List[Optional[RemoteModuleInfo]]:
+    """
+    Fetch the DHT records for uids and parse them into RemoteModuleInfo objects
+
+    Malformed records and malformed peer entries are logged and skipped.
+    A uid with no valid peer entry yields None.
+    """
+    if expiration_time is None:
+        expiration_time = get_dht_time()
+    num_uids = len(uids)
+    num_workers = num_uids if dht.num_workers is None else min(num_uids, dht.num_workers)
+    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
+
+    modules: List[Optional[RemoteModuleInfo]] = [None] * num_uids
+    for i, uid in enumerate(uids):
+        metadata = found[uid]
+        if metadata is None:
+            continue
+        if not isinstance(metadata.value, dict):
+            logger.error(f"Incorrect metadata for {uid}: {metadata}")
+            continue
+
+        servers = {}
+        for peer_id, server_info in metadata.value.items():
+            try:
+                peer_id = PeerID.from_base58(peer_id)
+                state, throughput = server_info.value
+                is_valid = (
+                    isinstance(state, int)
+                    and isinstance(throughput, float)
+                    and math.isfinite(throughput)
+                    and throughput >= 0.0
+                )
+                if not is_valid:
+                    raise ValueError(f"Invalid server info: {server_info}")
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except (TypeError, ValueError) as e:
+                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
+
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
+    return modules
diff --git a/petals/src/server/__init__.py b/petals/src/server/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/petals/src/server/backend.py b/petals/src/server/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f55dddae24857f927c5ceef2a22a287ee21755
--- /dev/null
+++ b/petals/src/server/backend.py
@@ -0,0 +1,84 @@
+"""Code for serving bloom blocks via hivemind-server"""
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+import torch
+from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+
+from src.bloom.from_pretrained import BloomBlock
+from src.server.cache import MemoryCache
+from src.server.task_pool import PrioritizedTaskPool
+from src.utils.misc import is_dummy
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class TransformerBackend(ModuleBackend):
+    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
+        """
+        :param memory_cache: cache from which inference_step obtains attention caches by handle
+        :param backend_dtype: dtype used in the inference schema; defaults to the dtype of the
+          wrapped block's input_layernorm weight
+        """
+        super().__init__(*args, **kwargs)
+        assert isinstance(self.module, BloomBlock)
+        self.memory_cache = memory_cache
+        for name, param in self.module.named_parameters():
+            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+        for name, buf in self.module.named_buffers():
+            assert not buf.requires_grad, f"Bloom layer buffers must not accumulate gradients, but {name} does"
+
+        # read the batch size from the pool created by the parent class before replacing that pool
+        max_batch_size = self.forward_pool.max_batch_size
+        self.inference_pool = PrioritizedTaskPool(
+            self.inference_step, max_batch_size=max_batch_size, name=f"{self.name}_inference"
+        )
+        self.forward_pool = PrioritizedTaskPool(
+            self.forward, max_batch_size=max_batch_size, name=f"{self.name}_forward"
+        )
+        self.backward_pool = PrioritizedTaskPool(
+            self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
+        )
+        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        # inference takes the regular args plus two extra tensors: one of the backend dtype and one int64
+        # NOTE(review): the int64 one is presumably hypo_ids (see inference_step) — confirm against the handler
+        self.inference_schema = (
+            (
+                *self.args_schema,
+                BatchTensorDescriptor((), dtype=self.dtype),
+                BatchTensorDescriptor((), dtype=torch.int64),
+            ),
+            self.kwargs_schema,
+        )
+
+    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        """
+        Run one incremental forward step of the block using (and updating) a stored attention cache
+
+        :param cache_metadata: tensor whose [0, 0] is the cache handle and [0, 1] is the prefix length
+        :param inputs: (hidden_states, hypo_ids); hidden_states must be [batch_size, seq_len, hid_size]
+        :returns: a one-element tuple with the new hidden states
+        """
+        with torch.inference_mode():
+            attention_cache_handle = int(cache_metadata[0, 0].item())
+            prefix_length = int(cache_metadata[0, 1].item())
+            (hidden_states, hypo_ids) = inputs
+            assert (
+                hidden_states.ndim == 3
+            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+
+            with self.memory_cache.use_cache(attention_cache_handle) as cache:
+                # cache[0] holds keys and cache[1] holds values
+                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                if not is_dummy(hypo_ids):
+                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
+                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
+                hidden_states, (new_k, new_v) = self.module.forward(
+                    hidden_states, layer_past=layer_past, use_cache=True
+                )
+
+                # todo remove these asserts once we pass all tests
+                new_length = new_v.shape[1]
+                assert new_length > prefix_length
+                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+                # write back only the newly produced positions
+                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                return (hidden_states,)
+
+    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
+        """Return the forward, backward and inference pools, in this order"""
+        return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
diff --git a/petals/src/server/block_selection.py b/petals/src/server/block_selection.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6352b547da5afd97dbc3d3740f0754da18e38cf
--- /dev/null
+++ b/petals/src/server/block_selection.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from hivemind import PeerID, get_logger
+
+from src.data_structures import RemoteModuleInfo, ServerState
+
+__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+
+logger = get_logger(__file__)
+
+
@dataclass
class Span:
    """A contiguous range of blocks ``[start, end)`` together with a throughput value."""

    start: int
    end: int
    throughput: float

    @property
    def length(self):
        """Number of blocks covered by the span (``end - start``)."""
        return self.end - self.start

    def move_to(self, new_start: int) -> None:
        """Relocate the span so that it begins at ``new_start``, preserving its length."""
        span_length = self.length  # must be read before ``start`` is overwritten
        self.start = new_start
        self.end = new_start + span_length
+
+
def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
    """
    Summarize which blocks every non-offline server covers and the total throughput available per block.

    :param module_infos: one entry per block, in block order; ``None`` entries are skipped
    :returns: ``(spans, throughputs)``: ``spans`` maps each peer id to a Span stretching from the first to
      the last block that peer was seen on; ``throughputs[i]`` is the summed throughput of servers on block ``i``
    """
    spans = {}
    throughputs = np.zeros(len(module_infos))
    for block, module in enumerate(module_infos):
        if module is None:
            continue

        for peer_id, server in module.servers.items():
            if server.state == ServerState.OFFLINE:
                continue

            span = spans.get(peer_id)
            if span is None:
                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
            else:
                span.start = min(span.start, block)
                # Extend the END of the span. The previous code compared against ``start`` here, which only
                # gave the right answer because blocks happen to be visited in ascending order.
                span.end = max(span.end, block + 1)

            throughputs[block] += server.throughput

    return spans, throughputs
+
+
+def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
+ options = (
+ (sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
+ for i in range(0, len(throughputs) - num_blocks + 1)
+ )
+ return min(options)[-1]
+
+
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
    """Return the indices of ``num_blocks`` consecutive blocks chosen by ``_choose_best_start``."""
    throughputs = _compute_spans(module_infos)[1]  # the per-peer spans are not needed here
    first_block = _choose_best_start(throughputs, num_blocks, None)
    return [first_block + offset for offset in range(num_blocks)]
+
+
def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
) -> bool:
    """
    Decide whether this server should switch to other blocks.

    The function simulates moving the local server to its best position, then lets every server greedily
    move (in random order) until no one moves, and compares the minimum per-block throughput before and after.

    :param local_peer_id: peer id of this server; must be present in ``module_infos``
    :param module_infos: one entry per block, in block order
    :param balance_quality: rebalance if ``initial_min / simulated_min`` is below this value;
      values above 1.0 force rebalancing on every check
    :returns: True if the server should choose other blocks
    """
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

    eps = 1e-6
    spans, throughputs = _compute_spans(module_infos)
    initial_throughput = throughputs.min()

    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
    local_span = spans[local_peer_id]
    throughputs[local_span.start : local_span.end] -= local_span.throughput

    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
    if local_span.start == new_start:
        return False  # This server is on its best place already
    local_span.move_to(new_start)

    throughputs[local_span.start : local_span.end] += local_span.throughput

    # Simulate the other servers reacting to the move until the assignment is stable
    moved = True
    while moved:
        servers = list(spans.keys())
        np.random.shuffle(servers)

        moved = False
        for peer_id in servers:
            span = spans[peer_id]
            throughputs[span.start : span.end] -= span.throughput

            new_start = _choose_best_start(throughputs, span.length, span.start)
            if span.start != new_start:
                span.move_to(new_start)
                moved = True

            throughputs[span.start : span.end] += span.throughput

    new_throughput = throughputs.min()
    if new_throughput < initial_throughput or new_throughput < eps:
        # Rebalancing would not help (or some block stays unserved). Returning early also avoids a
        # division by zero below; the original expression evaluated to False in these cases as well.
        return False

    actual_quality = initial_throughput / new_throughput
    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")

    return actual_quality < balance_quality - eps
diff --git a/petals/src/server/cache.py b/petals/src/server/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e2d213b0b02d0c3313a6663e28b4ecdccafa3e4
--- /dev/null
+++ b/petals/src/server/cache.py
@@ -0,0 +1,143 @@
+"""
+A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
+
+For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
+
+"""
+import asyncio
+import contextlib
+import ctypes
+import multiprocessing as mp
+import os
+import time
+from typing import AsyncContextManager, Dict, Optional, Union
+
+import hivemind
+import torch
+from hivemind import use_hivemind_log_handler
+from hivemind.utils import TensorDescriptor, get_logger
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+Handle = int
+
+
+class MemoryCache:
+ """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
+
+ def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
+ self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
+ self.device = device
+ self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+ self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+ self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
+ self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
+ self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+ self.runtime_pid = os.getpid()
+
+ self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False) # any ConnectionHandler -> runtime
+ self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+ self._lock_acquire_memory = mp.Lock()
+ self._memory_freed_event = mp.Event()
+
+ @property
+ def current_size_bytes(self) -> int:
+ return self._current_size.value
+
+ @current_size_bytes.setter
+ def current_size_bytes(self, value: int):
+ self._current_size.value = value
+
+ @property
+ def handle_counter(self) -> int:
+ return self._handle_counter.value
+
+ @handle_counter.setter
+ def handle_counter(self, value: int):
+ self._handle_counter.value = value
+
+ @contextlib.asynccontextmanager
+ async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
+ """
+ Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
+
+ :param descr: allocate a tensor of this size, dtype, etc
+
+ :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
+ Furthermore, it can be called concurrently with at most one use_cache call in runtime.
+ """
+ assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
+ assert descr.device is None and descr
+ allocated_handle = None
+ allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+ loop = asyncio.get_event_loop()
+ try:
+ async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+ if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+ await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+ async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+ allocated_handle = int(self.handle_counter)
+ self.current_size_bytes += allocated_size_bytes
+ self.handle_counter += 1 # note: this will eventually overflow and it is okay
+ self._pending_messages.value += 1
+ self._pipe_send.send((allocated_handle, descr))
+
+ yield allocated_handle
+ finally:
+ if allocated_handle is not None:
+ async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+ self._pending_messages.value += 1
+ self._pipe_send.send((allocated_handle, None)) # signal runtime to free that handle
+ self.current_size_bytes -= allocated_size_bytes
+ self._memory_freed_event.set()
+
+ def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+ # note: this function should only be called inside _lock_acquire_memory!
+ if allocated_size_bytes > self.max_size_bytes:
+ raise AllocationFailed(
+ f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+ )
+ deadline = None if timeout is None else time.perf_counter() + timeout
+ while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+ remaining_time = deadline - time.perf_counter() if timeout is not None else None
+ if not self._memory_freed_event.wait(remaining_time):
+ raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+ self._memory_freed_event.clear()
+
+ @contextlib.contextmanager
+ def use_cache(self, handle: Handle) -> torch.Tensor:
+ """
+ Return a tensor that was previously allocated with try_allocate_cache,
+
+ :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
+ However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
+ """
+ assert os.getpid() == self.runtime_pid
+ # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
+
+ with self._lock_metadata:
+ if self._allocated_tensors is None:
+ self._allocated_tensors = {}
+
+ # read creation/deletion requests from connection handlers
+ for i in range(int(self._pending_messages.value)):
+ recv_handle, recv_data = self._pipe_recv.recv()
+ self._pending_messages.value -= 1
+ if isinstance(recv_data, TensorDescriptor):
+ self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
+ elif recv_data is None:
+ if recv_handle not in self._allocated_tensors:
+ logger.warning(
+ f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
+ )
+ self._allocated_tensors.pop(recv_handle, None)
+ else:
+ logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")
+
+ assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+ yield self._allocated_tensors[handle]
+
+
class AllocationFailed(Exception):
    """Raised when the memory cache cannot reserve the requested number of bytes."""
diff --git a/petals/src/server/handler.py b/petals/src/server/handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c366e3437009e4e059e0ca5ab0e4db65f23b124
--- /dev/null
+++ b/petals/src/server/handler.py
@@ -0,0 +1,421 @@
+import contextlib
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+
+import torch
+from hivemind import (
+ DHT,
+ MSGPackSerializer,
+ P2PContext,
+ TensorDescriptor,
+ deserialize_tensor_stream,
+ deserialize_torch_tensor,
+ nested_flatten,
+ serialize_torch_tensor,
+)
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
+from hivemind.utils.streaming import split_for_streaming
+
+from src.data_structures import CHAIN_DELIMITER, ModuleUID
+from src.server.backend import TransformerBackend
+from src.server.task_pool import PrioritizedTaskPool
+from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
+from src.utils.misc import DUMMY, is_dummy
+
+logger = get_logger(__file__)
+
+
+class TransformerConnectionHandler(ConnectionHandler):
+ """Handles three request types: forward, backward and forward-incremental (inference)"""
+
+ module_backends: Dict[ModuleUID, TransformerBackend]
+
+ def __init__(
+ self,
+ dht: DHT,
+ module_backends: Dict[str, TransformerBackend],
+ inference_max_length: int,
+ task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
+ ):
+ super().__init__(dht, module_backends)
+ for module_backend in self.module_backends.values():
+ assert isinstance(module_backend, TransformerBackend)
+ self.inference_max_length = inference_max_length
+ self._prioritizer = task_prioritizer
+
+ async def _gather_inputs(
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+ ) -> Tuple[str, List[torch.Tensor], Dict]:
+ block_uid, metadata = None, None
+
+ def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
+ nonlocal block_uid, metadata
+
+ if block_uid is None:
+ block_uid = req.uid
+ elif block_uid != req.uid:
+ raise ValueError("Block uids differ in one request")
+
+ if metadata is None:
+ metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}
+
+ return req.tensors
+
+ tensors_stream = amap_in_executor(_unpack, requests)
+ inputs = await deserialize_tensor_stream(tensors_stream)
+ assert isinstance(block_uid, str) and isinstance(metadata, dict)
+ return block_uid, inputs, metadata
+
+ async def rpc_inference(
+ self,
+ requests: AsyncIterator[runtime_pb2.ExpertRequest],
+ context: P2PContext,
+ ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+ """Compute a single step of inference using attention cache; update attention cache accordingly."""
+ try:
+ logger.debug("Opened rpc_inference()")
+ request = await anext(requests)
+ requested_uids = self._check_uids(request.uid)
+ metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+ max_length = metadata.get("max_length")
+ points = metadata.get("points", 0)
+
+ if not requested_uids:
+ raise ValueError("User must specify at least one block for inference, but got none")
+ assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+ assert isinstance(
+ points, (float, int)
+ ), f"rpc_inference should have number of points as a number or None, got {points}"
+ if not 0 <= max_length <= self.inference_max_length:
+ raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
+
+ point_per_piece = points / max_length if max_length > 0 else 0.0
+ batch_size = request.tensors[0].size[0] if request.tensors else 1
+
+ cache_metadata = torch.tensor(
+ [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+ ) # [cache_handle, prefix_length]
+ prefix_length = 0
+
+ async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
+ assert len(cache_handles) == len(requested_backends)
+ while request.tensors: # iterate while user is willing to supply tensors
+ hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+ # Cast inputs to backend dtype
+ hidden_states = hidden_states.to(requested_backends[0].dtype)
+ assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+ # parse deep prompts (optional argument)
+ if prompts is None or is_dummy(prompts) or is_dummy(prompts):
+ prompts = [DUMMY] * len(requested_backends)
+ else:
+ prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+ if not (len(requested_backends) == len(prompts)):
+ raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+ length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
+ if prefix_length + length_increment > max_length:
+ raise ValueError(
+ f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+ f" exceeds pre-allocated maximum {max_length}"
+ )
+
+ # run request tensors through all requested modules, update caches
+ for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+ if not is_dummy(prompt):
+ hidden_states[:, : prompt.shape[1]] += prompt
+
+ cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+ assert isinstance(
+ hidden_states, torch.Tensor
+ ), f"hidden states must be tensor, got {type(hidden_states)}"
+ assert (
+ hidden_states.ndim == 3
+ ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+ assert isinstance(
+ backend.inference_pool, PrioritizedTaskPool
+ ), "petals support only prioritized pools"
+ priority = self._prioritizer.prioritize(
+ cache_metadata,
+ hidden_states,
+ hypo_ids,
+ points=point_per_piece / len(requested_backends),
+ backend=backend,
+ type="inference",
+ )
+ (hidden_states,) = await backend.inference_pool.submit_task(
+ cache_metadata, hidden_states, hypo_ids, priority=priority
+ )
+
+ # serialize and send last layer outputs
+ yield runtime_pb2.ExpertResponse(
+ tensors=[
+ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+ for result, proto in zip(
+ (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+ )
+ ]
+ )
+
+ # prepare for next step
+ prefix_length += hidden_states.shape[1]
+ request = await (anext(requests))
+ finally:
+ logger.debug("Closed rpc_inference()")
+
+ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+ # Parse request and prepare backends
+ flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+ requested_uids = self._check_uids(request.uid)
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+ metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+ points = metadata.get("points", 0)
+ assert isinstance(
+ points, (float, int)
+ ), f"rpc_forward should have number of points as number or None, got {points}"
+
+ hidden_states = await _rpc_forward(
+ *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+ )
+ assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+
+ # Serialize output and respond to client
+ return runtime_pb2.ExpertResponse(
+ tensors=[
+ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+ for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+ ]
+ )
+
+ async def rpc_forward_stream(
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+ ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+ # Parse requests and prepare backends
+ uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
+ requested_uids = self._check_uids(uid_str)
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+ points = metadata.get("points", 0)
+ assert isinstance(
+ points, (float, int)
+ ), f"rpc_forward_stream should have number of points as number or None, got {points}"
+
+ hidden_states = await _rpc_forward(
+ *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+ )
+ assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+
+ # Serialize the overall output
+ serialized_output = [
+ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+ for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+ ]
+
+ # Split the serialized_output for streaming and respond to client
+ output_split = [
+ part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+ ]
+ async for part in as_aiter(*output_split):
+ yield runtime_pb2.ExpertResponse(tensors=[part])
+
+ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+ # Parse requests and prepare backends
+ flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+ requested_uids = self._check_uids(request.uid)
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+ metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+ points = metadata.get("points", 0)
+ assert isinstance(
+ points, (float, int)
+ ), f"rpc_backward should have number of points as number or None, got {points}"
+
+ grads = await _rpc_backward(
+ *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+ )
+
+ # Modify grad_inputs_schema to support grad_prompts
+ assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
+
+ grad_inputs_schema_with_prompts = (
+ requested_backends[0].args_schema * len(grads),
+ requested_backends[0].kwargs_schema,
+ ) # TODO generalize
+
+ # Serialize the overall grad_input and respond
+ return runtime_pb2.ExpertResponse(
+ tensors=[
+ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+ for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+ ]
+ )
+
+ async def rpc_backward_stream(
+ self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+ ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+
+ uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
+ requested_uids = self._check_uids(uids_header)
+ requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+ points = metadata.get("points", 0)
+ assert isinstance(
+ points, (float, int)
+ ), f"rpc_backward_stream should have number of points as number or None, got {points}"
+
+ grads = await _rpc_backward(
+ *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+ )
+
+ # Modify grad_inputs_schema to support grad_prompts
+ assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
+ grad_inputs_schema_with_prompts = (
+ requested_backends[0].args_schema * len(grads),
+ requested_backends[0].kwargs_schema,
+ ) # TODO generalize
+
+ # Serialize the overall grad_inputs
+ serialized_grad_inputs = [
+ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+ for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+ ]
+ # Split the serialized_grad_inputs for streaming and respond
+ output_split = [
+ part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+ ]
+
+ async for part in as_aiter(*output_split):
+ yield runtime_pb2.ExpertResponse(tensors=[part])
+
+ def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+ """Check that the first request to rpc_inference is valid"""
+ uids = (uids or "").split(CHAIN_DELIMITER)
+ if not uids:
+ raise RuntimeError("User did not provide any uids")
+ for uid in uids:
+ if uid not in self.module_backends:
+ raise RuntimeError(f"Remote peer does not serve {uid}")
+ return tuple(uids)
+
+ @contextlib.asynccontextmanager
+ async def _allocate_caches(
+ self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+ ) -> Sequence[int]:
+ """Allocate memory caches for each transformer block, return cache handles"""
+ async with contextlib.AsyncExitStack() as stack:
+ handles = []
+ for backend in backends:
+ num_heads = backend.module.self_attention.num_heads
+ head_dim = backend.module.self_attention.head_dim
+
+ cache_descriptor = TensorDescriptor(
+ size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
+ )
+ # [key_or_value, batch_size, max_length, num_heads, head_dim]
+
+ handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
+
+ yield handles
+
+
async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: exactly two tensors: first layer inputs and prompts
    :note: prompts may be None or a dummy tensor (see is_dummy), in which case no prompts are added
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :param prioritizer: computes the priority of every task submitted to a backend's forward pool
    :param points: points attached to the request; split evenly between the requested backends
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        # one prompt per backend, split along the first dimension
        prompts = [p.squeeze(0) for p in prompts.to(dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        # validate the pool that is actually used below (the original asserted on inference_pool)
        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    # Serialization is left to the caller
    return hidden_states
+
+
async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """
    Run backward pass through a chain of blocks, used by rpc_backward and rpc_backward_stream

    The inputs are first re-run forward through all blocks but the last to recover each block's input,
    then gradients are propagated through the blocks in reverse order.

    :param flat_tensors: exactly three tensors: first layer inputs, gradients w.r.t. last layer outputs, prompts
    :note: prompts may be None or a dummy tensor (see is_dummy), in which case no prompts are used
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :param prioritizer: computes the priority of every task submitted to a backend pool
    :param points: points attached to the request; split evenly between the requested backends
    :returns: ``[grad_inputs]``, or ``[grad_inputs, grad_prompts]`` if any prompt was not a dummy
    """
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        # validate the pool that is actually used below (the original asserted on inference_pool)
        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        assert isinstance(backend.backward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            # the gradient w.r.t. a prompt is the gradient of the input positions it was added to
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
diff --git a/petals/src/server/runtime.py b/petals/src/server/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..11547aa9e344460e730a78c5568920e2413034fc
--- /dev/null
+++ b/petals/src/server/runtime.py
@@ -0,0 +1,198 @@
+import multiprocessing as mp
+import multiprocessing.pool
+import threading
+from collections import defaultdict
+from itertools import chain
+from queue import SimpleQueue
+from selectors import EVENT_READ, DefaultSelector
+from statistics import mean
+from time import time
+from typing import Dict, NamedTuple, Optional
+
+import torch
+from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.utils import get_logger
+from prefetch_generator import BackgroundGenerator
+
+logger = get_logger(__name__)
+
+
+class Runtime(threading.Thread):
+ """
+ A group of processes that processes incoming requests for multiple module backends on a shared device.
+ Runtime is usually created and managed by Server, humans need not apply.
+
+ For debugging, you can start runtime manually with .start() or .run()
+
+ >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
+ >>> runtime = Runtime(module_backends)
+ >>> runtime.start() # start runtime in background thread. To start in current thread, use runtime.run()
+ >>> runtime.ready.wait() # await for runtime to load all blocks on device and create request pools
+ >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
+ >>> print("Returned:", future.result())
+ >>> runtime.shutdown()
+
+ :param module_backends: a dict [block uid -> ModuleBackend]
+ :param prefetch_batches: form up to this many batches in advance
+ :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
+ :param device: if specified, moves all blocks and data to this device via .to(device=device).
+ If you want to manually specify devices for each block (in their forward pass), leave device=None (default)
+
+ :param stats_report_interval: interval to collect and log statistics about runtime performance
+ """
+
+ SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
+
+ def __init__(
+ self,
+ module_backends: Dict[str, ModuleBackend],
+ prefetch_batches: int = 1,
+ sender_threads: int = 1,
+ device: torch.device = None,
+ stats_report_interval: Optional[int] = None,
+ ):
+ super().__init__()
+ self.module_backends = module_backends
+ self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
+ self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
+ self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
+ self.shutdown_trigger = mp.Event()
+ self.ready = mp.Event() # event is set iff server is currently running and ready to accept batches
+
+ self.stats_report_interval = stats_report_interval
+ if self.stats_report_interval is not None:
+ self.stats_reporter = StatsReporter(self.stats_report_interval)
+
+ def run(self):
+ for pool in self.pools:
+ if not pool.is_alive():
+ pool.start()
+ if self.device is not None:
+ for backend in self.module_backends.values():
+ backend.module.to(self.device)
+
+ with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
+ try:
+ self.ready.set()
+ if self.stats_report_interval is not None:
+ self.stats_reporter.start()
+ logger.info("Started")
+
+ batch_iterator = self.iterate_minibatches_from_pools()
+ if self.prefetch_batches > 0:
+ batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)
+
+ for pool, batch_index, batch in batch_iterator:
+ logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
+
+ start = time()
+ try:
+ outputs = pool.process_func(*batch)
+ output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+
+ batch_processing_time = time() - start
+
+ batch_size = outputs[0].size(0)
+ logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+
+ if self.stats_report_interval is not None:
+ self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+ except KeyboardInterrupt:
+ raise
+ except BaseException as exception:
+ logger.exception(f"Caught {exception}, attempting to recover")
+ output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
+
+ finally:
+ if not self.shutdown_trigger.is_set():
+ self.shutdown()
+
+ def shutdown(self):
+ """Gracefully terminate a running runtime."""
+ logger.info("Shutting down")
+ self.ready.clear()
+
+ if self.stats_report_interval is not None:
+ self.stats_reporter.stop.set()
+ self.stats_reporter.join()
+
+ logger.debug("Terminating pools")
+ for pool in self.pools:
+ if pool.is_alive():
+ pool.shutdown()
+ logger.debug("Pools terminated")
+
+ # trigger background thread to shutdown
+ self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
+ self.shutdown_trigger.set()
+
+ def iterate_minibatches_from_pools(self, timeout=None):
+ """
+ Chooses pool according to priority, then copies exposed batch and frees the buffer
+ """
+ with DefaultSelector() as selector:
+ for pool in self.pools:
+ selector.register(pool.batch_receiver, EVENT_READ, pool)
+ selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)
+
+ while True:
+ # wait until at least one batch_receiver becomes available
+ logger.debug("Waiting for inputs from task pools")
+ ready_fds = selector.select()
+ ready_objects = {key.data for (key, events) in ready_fds}
+ if self.SHUTDOWN_TRIGGER in ready_objects:
+ break # someone asked us to shutdown, break from the loop
+
+ logger.debug("Choosing the pool with first priority")
+
+ pool = min(ready_objects, key=lambda pool: pool.priority)
+
+ logger.debug(f"Loading batch from {pool.name}")
+ batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+ logger.debug(f"Loaded batch from {pool.name}")
+ yield pool, batch_index, batch_tensors
+
+
+BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))
+
+
class StatsReporter(threading.Thread):
    """Background thread that periodically logs per-pool batch statistics submitted via report_stats()."""

    def __init__(self, report_interval: int):
        """
        :param report_interval: number of seconds between two reports
        """
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
        """Every ``report_interval`` seconds, drain the queue and log aggregated stats; exit once ``stop`` is set."""
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
                batch_performance = self._format_rate(total_batches, total_time, "batch", "batches")
                example_performance = self._format_rate(total_examples, total_time, "example", "examples")

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    @staticmethod
    def _format_rate(count: float, seconds: float, singular: str, plural: str) -> str:
        """
        Format a processing speed as ``"<x> <plural>/s"`` when faster than one item per second,
        otherwise as ``"<x> s/<singular>"`` (seconds per item).

        Zero ``count`` or non-positive ``seconds`` are handled without dividing by zero.
        """
        if count <= 0:
            return f"0.00 {plural}/s"
        if seconds <= 0:
            return f"inf {plural}/s"
        rate = count / seconds
        if rate > 1:
            return f"{rate:.2f} {plural}/s"
        # invert the value together with the unit: 0.5 items/s is 2 s/item, not 0.5 s/item
        return f"{seconds / count:.2f} s/{singular}"

    def report_stats(self, pool_uid, batch_size, processing_time):
        """Queue the stats of one processed batch for the next report; does not block."""
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
diff --git a/petals/src/server/server.py b/petals/src/server/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..174762dc2ccc022ad9a79441962963492e221c18
--- /dev/null
+++ b/petals/src/server/server.py
@@ -0,0 +1,475 @@
+from __future__ import annotations
+
+import gc
+import multiprocessing as mp
+import random
+import threading
+import time
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+import psutil
+import torch
+from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
+from hivemind.moe.server.layers import add_custom_models_from_file
+from hivemind.moe.server.runtime import Runtime
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from src import BloomConfig, declare_active_modules
+from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+from src.dht_utils import get_remote_module_infos
+from src.server import block_selection
+from src.server.backend import TransformerBackend
+from src.server.cache import MemoryCache
+from src.server.handler import TransformerConnectionHandler
+from src.server.throughput import get_host_throughput
+from src.utils.convert_8bit import replace_8bit_linear
+
+# NOTE(review): presumably routes hivemind's log records through the root logger - confirm in hivemind.utils.logging
+use_hivemind_log_handler("in_root_logger")
+# Module-level logger, created from this file's path; used by the server classes below
+logger = get_logger(__file__)
+
+
+class Server(threading.Thread):
+    """
+    Runs ModuleContainer, periodically checks that the network is balanced,
+    restarts the ModuleContainer with other layers if the imbalance is significant
+    """
+
+    def __init__(
+        self,
+        prefix: Optional[str],
+        converted_model_name_or_path: str,
+        throughput: Union[float, str],
+        num_blocks: Optional[int] = None,
+        block_indices: Optional[str] = None,
+        num_handlers: int = 8,
+        min_batch_size: int = 1,
+        max_batch_size: int = 4096,
+        inference_max_length: int = 4096,
+        torch_dtype: str = "auto",
+        revision: str = "main",
+        cache_dir: Optional[str] = None,
+        attn_cache_size: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        initial_peers: Sequence[str] = (),
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        prefetch_batches: int = 1,
+        sender_threads: int = 1,
+        balance_quality: float = 0.75,
+        mean_balance_check_period: float = 60,
+        mean_block_selection_delay: float = 0.5,
+        use_auth_token: Optional[str] = None,
+        load_in_8bit: bool = False,
+        *,
+        start: bool,
+        **kwargs,
+    ):
+        """
+        Create a server with one or more bloom blocks. See run_server.py for documentation.
+
+        Exactly one of ``num_blocks`` and ``block_indices`` must be given. ``block_indices`` is a "start:end"
+        string (end exclusive) that pins the served blocks; with ``num_blocks`` the blocks are chosen (and later
+        re-chosen) from the current swarm state. ``throughput`` is either a number or one of "auto" / "eval",
+        in which case it is obtained from get_host_throughput. Extra ``kwargs`` are forwarded to the DHT.
+        """
+
+        super().__init__()
+
+        self.converted_model_name_or_path = converted_model_name_or_path
+        self.num_handlers = num_handlers
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.cache_dir = cache_dir
+        self.attn_cache_size = attn_cache_size
+        self.compression = compression
+        self.stats_report_interval, self.update_period = stats_report_interval, update_period
+        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
+        self.use_auth_token = use_auth_token
+        self.load_in_8bit = load_in_8bit
+
+        if custom_module_path is not None:
+            add_custom_models_from_file(custom_module_path)
+
+        if prefix is None:
+            # Fall back to the model name, which is only usable if it contains no DHT key delimiters
+            prefix = converted_model_name_or_path
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
+                f"Please specify --prefix manually when starting a server"
+            )
+            logger.info(f"Automatic dht prefix: {prefix}")
+        self.prefix = prefix
+
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+        self.expiration = expiration
+
+        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
+        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
+        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = device
+
+        self.memory_cache = MemoryCache(device, attn_cache_size)
+
+        # Accept any real number (ints included), not only float instances
+        assert isinstance(throughput, (int, float)) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
+        self.throughput = throughput
+
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
+        self.block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+        # One DHT uid per transformer block of the model
+        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        if block_indices is not None:
+            try:
+                first_block_index, last_block_index = block_indices.split(":")
+                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+            except Exception as e:
+                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
+                raise
+            block_indices = range(first_block_index, last_block_index)
+        self.strict_block_indices, self.num_blocks = block_indices, num_blocks
+        self.balance_quality = balance_quality
+        self.mean_balance_check_period = mean_balance_check_period
+        self.mean_block_selection_delay = mean_block_selection_delay
+
+        self.stop = threading.Event()  # set by shutdown() to end run()
+        if start:
+            self.start()
+
+    def run(self):
+        """
+        Serve blocks until ``stop`` is set: choose blocks, start a ModuleContainer for them, and periodically
+        check whether the swarm is imbalanced; if so, shut the container down and start over with new blocks.
+        """
+        while True:
+            block_indices = self._choose_blocks()
+            self.module_container = ModuleContainer.create(
+                dht=self.dht,
+                prefix=self.prefix,
+                converted_model_name_or_path=self.converted_model_name_or_path,
+                block_config=self.block_config,
+                memory_cache=self.memory_cache,
+                throughput=self.throughput,
+                block_indices=block_indices,
+                num_handlers=self.num_handlers,
+                min_batch_size=self.min_batch_size,
+                max_batch_size=self.max_batch_size,
+                inference_max_length=self.inference_max_length,
+                torch_dtype=self.torch_dtype,
+                cache_dir=self.cache_dir,
+                device=self.device,
+                compression=self.compression,
+                stats_report_interval=self.stats_report_interval,
+                update_period=self.update_period,
+                expiration=self.expiration,
+                prefetch_batches=self.prefetch_batches,
+                sender_threads=self.sender_threads,
+                use_auth_token=self.use_auth_token,
+                load_in_8bit=self.load_in_8bit,
+                start=True,
+            )
+            try:
+                self.module_container.ready.wait()
+
+                while True:
+                    # Randomized delay with the configured mean (uniform in [0, 2 * mean])
+                    timeout = random.random() * 2 * self.mean_balance_check_period
+                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
+                    if self.stop.wait(timeout):
+                        return
+
+                    if self._should_choose_other_blocks():
+                        logger.info("Swarm is imbalanced, server will load other blocks")
+                        break  # Stop serving this set of modules
+            finally:
+                # Runs both on rebalancing (break) and on shutdown (return)
+                self.module_container.shutdown()
+
+            self._clean_memory_and_fds()
+
+    def _clean_memory_and_fds(self):
+        """Drop the previous container, force garbage collection and log how many file descriptors remain open."""
+        del self.module_container
+        gc.collect()  # In particular, this closes unused file descriptors
+
+        cur_proc = psutil.Process()
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+
+    def _choose_blocks(self) -> Sequence[int]:
+        """Return the block indices to serve: the pinned range if one was given, otherwise the best blocks."""
+        if self.strict_block_indices is not None:
+            return self.strict_block_indices
+        assert self.num_blocks is not None
+
+        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
+        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
+        time.sleep(random.random() * 2 * self.mean_block_selection_delay)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.choose_best_blocks(self.num_blocks, module_infos)
+
+    def _should_choose_other_blocks(self) -> bool:
+        """Return True if the blocks were chosen automatically and block_selection reports an imbalance."""
+        if self.strict_block_indices is not None:
+            return False
+
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
+
+    def shutdown(self):
+        """Signal run() to stop, then shut down the DHT node and wait for it to exit."""
+        self.stop.set()
+
+        self.dht.shutdown()
+        self.dht.join()
+
+
+class ModuleContainer(threading.Thread):
+    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
+
+    def __init__(
+        self,
+        dht: DHT,
+        module_backends: Dict[str, TransformerBackend],
+        *,
+        inference_max_length: int,
+        num_connection_handlers: int,
+        throughput: float,
+        update_period: float,
+        expiration: Optional[float] = None,
+        start: bool,
+        **kwargs,
+    ):
+        """
+        Build connection handlers, the runtime and the announcer thread for ``module_backends``.
+
+        :param dht: DHT node used by the handlers and for announcing the served modules
+        :param module_backends: mapping from module uid to the backend that serves it
+        :param inference_max_length: forwarded to every TransformerConnectionHandler
+        :param num_connection_handlers: number of connection handlers to create
+        :param throughput: value announced alongside the served modules
+        :param update_period: seconds between two consecutive announcements
+        :param expiration: lifetime (seconds) of each announcement; shutdown() adds it to the DHT time,
+            so it must not be None by then
+        :param start: if True, start the container in the background and wait until it is ready
+        :param kwargs: forwarded to Runtime
+        """
+        super().__init__()
+
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.conn_handlers = [
+            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
+            for _ in range(num_connection_handlers)
+        ]
+        self.runtime = Runtime(self.module_backends, **kwargs)
+        self.dht_handler_thread = ModuleAnnouncerThread(
+            self.module_backends,
+            dht,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        if start:
+            self.run_in_background(await_ready=True)
+
+    def run(self):
+        """
+        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
+        runs Runtime (self.runtime) to process incoming requests.
+        """
+        logger.info(f"Serving {len(self.module_backends)} blocks:")
+        for expert_name, backend in self.module_backends.items():
+            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
+            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
+
+        if not self.dht.is_alive():
+            self.dht.run_in_background(await_ready=True)
+
+        if self.module_backends:
+            self.dht_handler_thread.start()
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.start()
+
+        for handler in self.conn_handlers:
+            handler.run_in_background()
+
+        self.runtime.run()
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        """
+        Announce the chosen blocks as JOINING, load them, wrap each one in a TransformerBackend
+        and return a ModuleContainer that serves them.
+
+        Blocks are loaded with load_pretrained_block, optionally converted with replace_8bit_linear,
+        moved to ``device`` and frozen (requires_grad=False) before being wrapped.
+        """
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
+        )
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        blocks = {}
+        for module_uid, block_index in zip(module_uids, block_indices):
+            block = load_pretrained_block(
+                converted_model_name_or_path,
+                block_index,
+                block_config,
+                torch_dtype=torch_dtype,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+            )
+
+            if load_in_8bit:
+                block = replace_8bit_linear(block)
+
+            block = block.to(device)
+            for param in block.parameters():
+                param.requires_grad = False
+
+            blocks[module_uid] = TransformerBackend(
+                module_uid,
+                block,
+                memory_cache=memory_cache,
+                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                args_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
+                kwargs_schema={},
+                outputs_schema=(
+                    BatchTensorDescriptor(
+                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                    ),
+                ),
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
+    def run_in_background(self, await_ready=True, timeout=None):
+        """
+        Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
+        is ready to process incoming requests or for :timeout: seconds max.
+
+        :raises TimeoutError: if await_ready is set and the container is not ready within :timeout: seconds
+        """
+        self.start()
+        if await_ready and not self.ready.wait(timeout=timeout):
+            raise TimeoutError(f"ModuleContainer didn't notify .ready in {timeout} seconds")
+
+    @property
+    def ready(self) -> mp.synchronize.Event:
+        """
+        An event (multiprocessing.Event) that is set when the container is ready to process requests.
+
+        Example
+        =======
+        >>> container.start()
+        >>> container.ready.wait(timeout=10)
+        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
+        """
+        return self.runtime.ready  # mp.Event that is true if self is ready to process batches
+
+    def shutdown(self):
+        """
+        Gracefully terminate the container, process-safe.
+        Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
+        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
+        """
+        if self.module_backends:
+            # Stop the periodic ONLINE announcements, then tell peers that these blocks are gone
+            self.dht_handler_thread.stop.set()
+            self.dht_handler_thread.join()
+
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
+        self.ready.clear()
+
+        for handler in self.conn_handlers:
+            handler.shutdown()
+        logger.debug("Connection handlers terminated")
+
+        if self.checkpoint_saver is not None:
+            self.checkpoint_saver.stop.set()
+            self.checkpoint_saver.join()
+
+        logger.debug("Shutting down pools")
+        for pool in self.runtime.pools:
+            if pool.is_alive():
+                pool.shutdown()
+
+        logger.debug("Shutting down runtime")
+        self.runtime.shutdown()
+
+        logger.info("Module container shut down succesfully")
+
+
+class ModuleAnnouncerThread(threading.Thread):
+    """
+    Keeps re-declaring the hosted modules as ONLINE in the DHT so that other peers can find them.
+
+    :param module_backends: mapping whose keys are the uids of the hosted modules
+    :param dht: DHT node used for the declarations
+    :param throughput: value announced together with the modules
+    :param update_period: seconds to wait between two consecutive declarations
+    :param expiration: lifetime (seconds) of each declaration, counted from the current DHT time
+    :param kwargs: forwarded to threading.Thread
+    """
+
+    def __init__(
+        self,
+        module_backends: Dict[str, TransformerBackend],
+        dht: DHT,
+        *,
+        throughput: float,
+        update_period: float = 30,
+        expiration: float,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.module_backends, self.dht = module_backends, dht
+        self.throughput = throughput
+        self.update_period, self.expiration = update_period, expiration
+        self.stop = threading.Event()  # set this event to end run() after the current wait
+
+    def run(self) -> None:
+        """Declare the modules immediately, then once per ``update_period`` until ``stop`` is set."""
+        stop_requested = False
+        while not stop_requested:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
+                throughput=self.throughput,
+            )
+            # wait() returns True as soon as the stop event is set, False on timeout
+            stop_requested = self.stop.wait(self.update_period)
diff --git a/petals/src/server/task_pool.py b/petals/src/server/task_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..672248f942eba3e24d729db560d93c3d681c3f04
--- /dev/null
+++ b/petals/src/server/task_pool.py
@@ -0,0 +1,178 @@
+import ctypes
+import multiprocessing as mp
+import threading
+import time
+from dataclasses import dataclass, field
+from queue import PriorityQueue
+from typing import Any, Generator, List, Optional, Sequence, Tuple
+
+import torch
+from hivemind import MPFuture, get_logger, use_hivemind_log_handler
+from hivemind.moe.server.task_pool import TaskPoolBase
+
+# NOTE(review): presumably routes hivemind's log records through the root logger - confirm in hivemind.utils.logging
+use_hivemind_log_handler("in_root_logger")
+# Module-level logger, created from this file's path; used by the task pool below
+logger = get_logger(__file__)
+
+
+@dataclass(order=True, frozen=True)
+class Task:
+    """
+    One unit of work submitted to a task pool.
+
+    Tasks are ordered by (priority, time_submitted); the future and the arguments do not take part
+    in comparisons.
+    """
+
+    priority: float
+    time_submitted: float
+    future: MPFuture = field(compare=False)  # receives the result or the exception of this task
+    args: Sequence[torch.Tensor] = field(compare=False)  # positional inputs of this task
+
+    @property
+    def uid(self) -> int:
+        """Identifier of this task, taken from its future."""
+        task_future = self.future
+        return task_future._uid
+
+
+class PrioritizedTaskPool(TaskPoolBase):
+    """
+    Collects requests from multiple ConnectionHandler instances, orders them by priority for the Runtime and
+    hands results (or exceptions) back to the ConnectionHandler that submitted them. Runs a background process.
+    One PrioritizedTaskPool serves one specific function (e.g. layer1.forward, layer2.forward or layer1.backward).
+
+    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
+      This would require grouping requests of different length.
+
+    :param process_func: function to be applied to every formed batch; called by Runtime
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
+    :param max_batch_size: process at most this many inputs in a batch (a task contains one or several inputs)
+         Measured in the total number of tokens (i.e. batch size * sequence length)
+
+    :param name: pool name, used for logging
+    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
+    :param start: if True, start automatically at the end of __init__
+    """
+
+    def __init__(
+        self,
+        process_func: callable,
+        max_batch_size: int,
+        name: str,
+        min_batch_size=1,
+        daemon=True,
+        start=False,
+    ):
+        super().__init__(process_func, daemon=daemon, name=name)
+        self.min_batch_size = min_batch_size
+        self.max_batch_size = max_batch_size
+
+        self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
+        self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
+
+        # Moves tasks from the cross-process queue into the local priority queue
+        self._prioritizer_thread = threading.Thread(
+            name=self.name + "_prioritizer",
+            target=self._prioritize_tasks,
+            args=[self.submitted_tasks, self._ordered_tasks],
+            daemon=True,
+        )
+        self._dispatched_tasks = {}  # task uid -> task that was loaded to the runtime but not yet answered
+        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
+        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
+        # NOTE(review): the setter below writes self._priority, which is not assigned in this class;
+        # presumably TaskPoolBase.__init__ creates it - confirm against hivemind
+        self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
+
+        self._stop = mp.Event()
+        if start:
+            self.start()
+
+    @staticmethod
+    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+        """Read tasks from incoming queue and put them into a local priority queue; None ends the loop"""
+        task = submitted_tasks.get()
+        while task is not None:
+            ordered_tasks.put(task, block=True)
+            task = submitted_tasks.get()
+        logger.debug("Shutting down prioritizer thread")
+
+    def start(self):
+        """Start the prioritizer thread, then the pool itself; neither may be running already."""
+        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
+        self._prioritizer_thread.start()
+        super().start()
+
+    def shutdown(self, timeout: float = 3):
+        """Ask the pool to stop, wait up to ``timeout`` seconds, then terminate it if it is still alive."""
+        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
+        self._stop.set()
+
+        self.join(timeout)
+        if not self.is_alive():
+            return
+        logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
+        self.terminate()
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
+        """Add task to this pool's queue, return Future for its output"""
+        task = Task(priority, time.monotonic(), MPFuture(), args)
+        if self.get_task_size(task) > self.max_batch_size:
+            # Oversized tasks are rejected right away through their future
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
+            task.future.set_exception(exc)
+            return task.future
+
+        self.submitted_tasks.put(task)
+        self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
+        task_key = (task.priority, task.time_submitted)
+        if task_key < self.priority:
+            self.priority = task_key
+        return task.future
+
+    def get_task_size(self, task: Task) -> int:
+        """compute task processing complexity; defaults to the total number of tokens"""
+        if not task.args or task.args[0].ndim < 2:
+            return 1
+        first_input_shape = task.args[0].shape
+        return first_input_shape[0] * first_input_shape[1]
+
+    def load_batch_to_runtime(
+        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
+    ) -> Tuple[Any, List[torch.Tensor]]:
+        """receive next batch of arrays"""
+        task = self._ordered_tasks.get(block=True, timeout=timeout)
+        batch_inputs = []
+        for tensor in task.args:
+            on_device = tensor.detach().to(device, non_blocking=True)
+            batch_inputs.append(on_device.requires_grad_(tensor.requires_grad))
+        self._dispatched_tasks[task.uid] = task
+        self.batch_receiver.recv()  # reduce the number of active batches
+        if not self._ordered_tasks.empty():
+            # The pool's priority becomes that of the most important task still waiting
+            first_remaining_task: Task = self._ordered_tasks.queue[0]
+            self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted)
+        return task.uid, batch_inputs
+
+    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
+        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
+        shared_outputs = []
+        for tensor in batch_outputs:
+            cpu_tensor = tensor.to(device="cpu").share_memory_().detach()
+            shared_outputs.append(cpu_tensor.requires_grad_(tensor.requires_grad))
+
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is not None:
+            task.future.set_result(shared_outputs)
+            return
+        logger.error(
+            f"Internal error: task task with index {uid} is missing from the dictionary; Could not set result"
+        )
+
+    def send_exception_from_runtime(self, uid: int, exception: BaseException):
+        """report an exception for a batch that was previously loaded through load_batch_to_runtime"""
+        task = self._dispatched_tasks.pop(uid, None)
+        if task is not None:
+            task.future.set_exception(exception)
+            return
+        logger.error(
+            f"Internal error: task task with index {uid} is missing from the dictionary; "
+            f"Could not set exception {exception}"
+        )
+
+    def run(self, *args, **kwargs):
+        """Body of the background process: block until shutdown() sets the stop event."""
+        self._stop.wait()
+
+    @property
+    def empty(self):
+        """True if the counting pipe currently holds no entry for a submitted, not yet loaded task."""
+        has_pending = self.batch_receiver.poll()
+        return not has_pending
+
+    @property
+    def priority(self) -> Tuple[float, float]:
+        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
+        first_priority = float(self._priority.value)
+        first_timestamp = float(self._oldest_undispatched_timestamp.value)
+        return first_priority, first_timestamp
+
+    @priority.setter
+    def priority(self, item: Tuple[float, float]):
+        assert len(item) == 2
+        new_priority, new_timestamp = item[0], item[1]
+        self._priority.value = float(new_priority)
+        self._oldest_undispatched_timestamp.value = float(new_timestamp)
diff --git a/petals/src/server/task_prioritizer.py b/petals/src/server/task_prioritizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e3b886cf802f5fe6ef6a7a61df9cd36bfaaebf0
--- /dev/null
+++ b/petals/src/server/task_prioritizer.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+import torch
+from hivemind.moe.server.task_pool import Task
+
+
+class TaskPrioritizerBase(ABC):
+    """Interface for objects that decide how urgently a submitted task should be processed."""
+
+    @abstractmethod
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        """
+        Compute the priority of a task from the points given, the task inputs and any extra kwargs.
+
+        :returns: the task priority; lower priority is better
+        """
+
+
+class DummyTaskPrioritizer(TaskPrioritizerBase):
+    """Trivial prioritizer: every task gets the same priority, zero, whatever its inputs or points."""
+
+    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
+        """Ignore all arguments and return 0.0."""
+        constant_priority = 0.0
+        return constant_priority
diff --git a/petals/src/server/throughput.py b/petals/src/server/throughput.py
new file mode 100644
index 0000000000000000000000000000000000000000..f14e9366c52ab711f1a835b7c3018d06e729db71
--- /dev/null
+++ b/petals/src/server/throughput.py
@@ -0,0 +1,127 @@
+import fcntl
+import json
+import os
+import subprocess
+import tempfile
+import time
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Dict, Union
+
+import torch
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+from src import project_name
+from src.bloom.block import BloomBlock
+from src.bloom.model import BloomConfig
+from src.bloom.ops import build_alibi_tensor
+
+# NOTE(review): presumably routes hivemind's log records through the root logger - confirm in hivemind.utils.logging
+use_hivemind_log_handler("in_root_logger")
+# Module-level logger, created from this file's path
+logger = get_logger(__file__)
+
+
+# JSON file under the user's home directory where measured throughput is cached between runs
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
+# Lock file in the system temp directory; get_host_throughput takes an exclusive lock on it
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
+
+# Speed test script (two directories above this file, under cli/) that measure_network_rps runs with --json
+SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
+
+
+@dataclass
+class ThroughputInfo:
+    """Throughput measurements of this host, in requests per second (RPS); cached as JSON between runs."""
+
+    network_rps: float  # requests per second that the network connection can sustain
+    device_rps: Dict[str, float]  # requests per second for each measured device type ("cpu", "cuda")
+
+
+def get_host_throughput(
+    device: Union[str, torch.device],
+    force_eval: bool = False,
+    cache_path: Union[str, Path] = DEFAULT_CACHE_PATH,
+    lock_path: Union[str, Path] = DEFAULT_LOCK_PATH,
+) -> float:
+    """
+    Return the throughput of this host for ``device``, in requests per second.
+
+    The result is the minimum of the network throughput and the compute throughput of the device type.
+    Measurements are read from ``cache_path`` when possible and re-measured (then saved) when the cache is
+    missing, unreadable, lacks the requested device type, or when ``force_eval`` is set.
+
+    :param device: device whose type ("cpu", "cuda", ...) selects the compute throughput
+    :param force_eval: if True, ignore the cache and measure again
+    :param cache_path: JSON file with cached measurements (str or Path)
+    :param lock_path: file locked exclusively for the duration of the call (str or Path)
+    """
+    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
+    device = torch.device(device).type
+
+    # Callers may pass plain strings; the code below relies on Path attributes such as .parent
+    cache_path, lock_path = Path(cache_path), Path(lock_path)
+
+    # We use the system-wide lock since only one process at a time can measure the host throughput
+    os.makedirs(lock_path.parent, exist_ok=True)
+    with open(lock_path, "wb") as lock_fd:
+        logger.info("Loading throughput info")
+        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
+        # The OS will release the lock when lock_fd is closed or the process is killed
+
+        info = None
+        try:
+            if not force_eval and os.path.exists(cache_path):
+                with open(cache_path) as cache_fd:
+                    info = ThroughputInfo(**json.load(cache_fd))
+                if device not in info.device_rps:
+                    force_eval = True
+        except Exception:
+            # A broken cache is not fatal: log it and fall back to measuring
+            logger.exception(f"Failed to read throughput info from {cache_path}")
+            force_eval = True
+
+        if force_eval or info is None:
+            info = measure_throughput_info()
+            try:
+                os.makedirs(cache_path.parent, exist_ok=True)
+                with open(cache_path, "w") as cache_fd:
+                    json.dump(asdict(info), cache_fd)
+            except Exception:
+                # Saving is best-effort; the fresh measurement is still returned
+                logger.exception(f"Failed to save throughput info in {cache_path}")
+
+    throughput = min(info.network_rps, info.device_rps[device])
+    return throughput
+
+
+def measure_throughput_info() -> ThroughputInfo:
+    """Measure network throughput and the compute throughput of the CPU (and of CUDA, when available)."""
+    logger.info(
+        "Measuring network, CPU, and GPU throughput. This takes about a minute and will be cached for future runs"
+    )
+
+    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
+    reference_config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
+
+    measured_network_rps = measure_network_rps(reference_config)
+
+    measured_device_rps = {}
+    measured_device_rps["cpu"] = measure_device_rps("cpu", reference_config)
+    if torch.cuda.is_available():
+        measured_device_rps["cuda"] = measure_device_rps("cuda", reference_config)
+
+    return ThroughputInfo(network_rps=measured_network_rps, device_rps=measured_device_rps)
+
+
+def measure_network_rps(config: BloomConfig) -> float:
+    """
+    Run the speed test script and convert the slower of download/upload speed into requests per second.
+
+    One request is counted as ``config.hidden_size * 32`` bits.
+
+    :raises RuntimeError: if the speed test script exits with a non-zero code
+    """
+    proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
+    if proc.returncode != 0:
+        raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
+    network_info = json.loads(proc.stdout)
+
+    download_speed = network_info["download"]
+    upload_speed = network_info["upload"]
+    bits_per_request = config.hidden_size * 32
+    network_rps = min(download_speed, upload_speed) / bits_per_request
+
+    # Speeds are divided by 1e6 to be shown in Mbit/s
+    summary = (
+        f"Network throughput: "
+        f"{download_speed / 1e6:.2f} Mbit/s on download, "
+        f"{upload_speed / 1e6:.2f} Mbit/s on upload, "
+        f"{network_rps:.2f} RPS"
+    )
+    logger.info(summary)
+    return network_rps
+
+
+def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
+    """
+    Measure how many single-token inference steps per second one BloomBlock can run on ``device``.
+
+    The block is run ``n_steps`` times with a growing attention cache; only the forward passes are timed
+    (building the dummy input and the alibi tensor is excluded).
+
+    :param device: device to run on, e.g. "cpu" or "cuda"
+    :param config: model config used to build the block
+    :param layer_index: index passed to the BloomBlock constructor
+    :param n_steps: number of timed forward passes
+    :returns: steps per second
+    """
+    torch_device = torch.device(device)
+    is_cuda = torch_device.type == "cuda"
+
+    with torch.inference_mode():
+        block = BloomBlock(config, layer_index).to(device)
+        cache = None
+        elapsed = 0
+        for i in range(n_steps):
+            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
+            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)
+
+            # CUDA kernels are launched asynchronously: wait for pending work before and after the timed
+            # region, otherwise the timer measures kernel launch time rather than compute time
+            if is_cuda:
+                torch.cuda.synchronize(torch_device)
+            start_time = time.perf_counter()
+            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
+            if is_cuda:
+                torch.cuda.synchronize(torch_device)
+            elapsed += time.perf_counter() - start_time
+        device_rps = n_steps / elapsed
+
+    device_name = f"{torch.cuda.get_device_name(0)} GPU" if is_cuda else "CPU"
+    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
+
+    return device_rps
diff --git a/petals/src/utils/__init__.py b/petals/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/petals/src/utils/convert_8bit.py b/petals/src/utils/convert_8bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..534d6292bd97ad98b3cffcbe509a99bc1651775f
--- /dev/null
+++ b/petals/src/utils/convert_8bit.py
@@ -0,0 +1,41 @@
+import os
+
+import bitsandbytes as bnb
+import torch
+
+# Integer flag read from the PETALS_8BIT_BACKWARD environment variable (unset -> 0 -> False);
+# passed as memory_efficient_backward to the 8-bit linear layers created by replace_8bit_linear
+PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
+
+
+def replace_8bit_linear(model, threshold=6.0):
+    """
+    Recursively swap every `torch.nn.Linear` submodule of `model` for a `bnb.nn.Linear8bitLt` module from the
+    `bitsandbytes` library, which runs in mixed int8 precision as described in the paper `GPT3.int8(): 8-bit
+    Matrix Multiplication for Transformers at Scale`. Make sure that a `bitsandbytes` build matching the CUDA
+    version of your hardware is installed before calling this function: `pip install -i
+    https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where `XXX` is your CUDA version (e.g., 11.6 = 116).
+    Submodules named `lm_head` or `score` are left as regular `torch.nn.Linear` modules.
+    The model is modified in place and also returned.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module`, since the function calls itself on submodules.
+        threshold (`float`, *optional*):
+            `int8_threshold` for outlier detection as described in the aforementioned paper. Defaults to
+            `6.0`, the value used in the paper.
+    """
+    kept_as_linear = ("lm_head", "score")
+    for child_name, child in model.named_children():
+        # Descend first, so that nested linear layers are converted as well
+        if next(child.children(), None) is not None:
+            replace_8bit_linear(child, threshold)
+
+        if not isinstance(child, torch.nn.Linear) or child_name in kept_as_linear:
+            continue
+
+        quantized = bnb.nn.Linear8bitLt(
+            child.in_features,
+            child.out_features,
+            child.bias is not None,
+            has_fp16_weights=False,
+            threshold=threshold,
+            memory_efficient_backward=PETALS_8BIT_BACKWARD,
+        )
+        # Reuse the original weight values, wrapped as int8 parameters
+        quantized.weight = bnb.nn.Int8Params(child.weight.data, requires_grad=False, has_fp16_weights=False).to(
+            child.weight.dtype
+        )
+        model._modules[child_name] = quantized
+    return model
diff --git a/petals/src/utils/generation_algorithms.py b/petals/src/utils/generation_algorithms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8507a49a60f281882dc8b87b73183a1cb92d5528
--- /dev/null
+++ b/petals/src/utils/generation_algorithms.py
@@ -0,0 +1,78 @@
+from abc import ABC
+from typing import Tuple
+
+import torch
+
+# Aliases naming the two tensors returned by the decoding algorithms below: chosen token ids and hypothesis ids
+TokenIds = torch.Tensor
+HypoIds = torch.Tensor
+
+
+class DecodingAlgorithm(ABC):
+    """
+    Common base of all decoding algorithms: given logits, an algorithm picks the next tokens and reports
+    which hypothesis each of them continues.
+    """
+
+    def __init__(self) -> None:
+        """The base class keeps no state."""
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypotheses. The shape of the token ids is
+            (batch_size, seq_length) and the shape of the hypotheses is (batch_size)
+        """
+
+
+class GreedyAlgorithm(DecodingAlgorithm):
+    """
+    The simplest decoding algorithm: always pick the token with the highest logit.
+    """
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        Return the most probable token for every row of ``logits``, with a trailing dimension of size 1.
+        The second returned tensor is always the range of integers from 0 to batch_size - 1.
+        """
+        best_token_ids = logits.max(-1)[1]
+        hypo_ids = torch.arange(logits.size(0))
+        return best_token_ids.unsqueeze(1), hypo_ids
+
+
+class SamplingAlgorithm(DecodingAlgorithm):
+    """
+    Base class of the sampling-based decoding algorithms.
+
+    Subclasses decide which vocabulary entries to exclude and must set ``self.temperature``,
+    which :meth:`sample` reads.
+    """
+
+    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """
+        Sample one token per row from the softmax of ``logits / temperature``, ignoring the excluded entries.
+
+        The caller's ``logits`` tensor is left untouched; masking is done on a copy.
+
+        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
+        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
+        :return: A tuple of sampled token ids, of shape (batch_size * num_hypos, 1), and hypothesis ids,
+            which are the integers from 0 to logits.size(0) - 1.
+        """
+        # masked_fill returns a new tensor, so the caller's logits are not overwritten
+        masked_logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+        probs = torch.softmax(masked_logits / self.temperature, -1)
+        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
+
+
+class TopKAlgorithm(SamplingAlgorithm):
+    """
+    Sampling restricted to the ``top_k`` highest-scoring tokens of every row.
+
+    :param top_k: number of candidates to keep; values larger than the vocabulary size keep every token
+    :param temperature: logits are divided by this value before the softmax
+    """
+
+    # TODO: Add NumHypos, maxBatchSize
+    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        self.top_k = top_k
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """Exclude every token scoring below the k-th best one of its row, then sample from the rest."""
+        # torch.topk raises if k exceeds the size of the dimension, so clamp it to the vocabulary size
+        top_k = min(self.top_k, logits.size(-1))
+        kth_best_logit = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
+        indices_to_remove = logits < kth_best_logit
+        return self.sample(logits, indices_to_remove)
+
+
+class NucleusAlgorithm(SamplingAlgorithm):
+    """
+    Nucleus (top-p) sampling: tokens are taken in order of decreasing probability and the tail that starts
+    after the cumulative probability first exceeds ``top_p`` is excluded.
+
+    :param top_p: cumulative probability threshold
+    :param temperature: logits are divided by this value before the softmax
+    """
+
+    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        self.top_p = top_p
+        self.temperature = temperature
+
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        """Build the exclusion mask in sorted order, map it back to vocabulary order and sample."""
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits / self.temperature, -1), dim=-1)
+        exceeds_top_p = cumulative_probs > self.top_p
+
+        # Shift the mask one position to the right: the token that crosses the threshold is kept,
+        # and the most probable token is never excluded
+        always_kept = torch.zeros_like(exceeds_top_p[..., :1])
+        sorted_indices_to_remove = torch.cat((always_kept, exceeds_top_p[..., :-1]), dim=-1)
+
+        # Undo the sort so that the mask lines up with the original vocabulary order
+        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
+        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+        return self.sample(logits, indices_to_remove)
+
+
+# TODO: In generate function we need to check usage of top_k or sampling algorithm
diff --git a/petals/src/utils/generation_constraints.py b/petals/src/utils/generation_constraints.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c526fde38189075ea7cc5e15ce81c30c074a37
--- /dev/null
+++ b/petals/src/utils/generation_constraints.py
@@ -0,0 +1,84 @@
+from abc import ABC
+
+import torch
+
+
+class ABCBloomConstraint(ABC):
+ """
+ Base class for all kinds of decoding constraints. Subclass it to implement a new constraint.
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+ """
+ Called by the decoding algorithm to apply the constraint; subclasses change and return the logits (this base version does nothing and returns None).
+ :param tokens_id: The token id of the last chosen token.
+ :param logits: The logits from the Bloom model.
+ :param hypo_ids: The hypothesis ids of the last tokens.
+ """
+ pass
+
+
+class MaxNewTokensConstraint(ABCBloomConstraint):
+ """
+ Constraint that forbids generating more than max_new_tokens tokens after the prefix.
+
+ Args:
+ prefix: The prefix of the sequence.
+ max_new_tokens: The maximum number of tokens that can be generated after the prefix.
+ eos_token_id: The id of the end of sentence token.
+ pad_token_id: The id of the padding token.
+ min_logits: The offset added to every logit of a row that has reached the limit. Default: -1e8.
+ """
+
+ def __init__(
+ self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
+ ) -> None:
+ self.max_new_tokens = max_new_tokens
+ self.current_generated_tokens = None  # placeholder, replaced by the tensor computed at the end of __init__
+ self.eos_token_id = eos_token_id
+ self.min_logits = min_logits
+
+ max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()  # largest per-row count of pad tokens in the prefix
+ self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size  # per-row pad count minus that maximum, so values are <= 0
+
+ def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+ if tokens_id is not None:  # the counter advances only on calls that receive a token
+ self.current_generated_tokens += 1
+
+ mask = self.current_generated_tokens >= self.max_new_tokens  # rows that have reached the limit
+ logits += self.min_logits * mask  # in-place: shifts every logit of the masked rows by min_logits
+ logits[mask[:, 0], self.eos_token_id] = 0  # ...then resets their EOS logit to 0, leaving EOS as the largest entry
+ return logits
+
+
+class EosConstraint(ABCBloomConstraint):
+ """
+ This constraint repeats the EOS token if it was generated on the previous step.
+ Args:
+ prefix: The prefix of the sequence.
+ eos_token_id: The id of the end of sentence token.
+ pad_token_id: The id of the padding token.
+ min_logits: The offset added to every logit of a row that already produced EOS. Default: -1e8.
+ """
+
+ def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
+ self.eos_token_id = eos_token_id
+ self.min_logits = min_logits
+ self.past_tokens = None  # tokens passed to the previous call; None until a call receives tokens
+
+ self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)  # per-row count of pad tokens in the prefix
+
+ def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
+ if self.past_tokens is not None:
+ mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)  # counter ran below zero and the previous token was EOS
+ logits += self.min_logits * mask  # in-place: shifts every logit of the masked rows by min_logits
+ logits[mask[:, 0], self.eos_token_id] = 0  # ...then resets their EOS logit to 0, leaving EOS as the largest entry
+
+ if tokens_id is not None:
+ self.past_tokens = tokens_id
+ self.wait_until_starting -= 1  # counts down once per call that receives a token
+
+ return logits
diff --git a/petals/src/utils/misc.py b/petals/src/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f67202304e3d74766801f49a4784730ea3840ea
--- /dev/null
+++ b/petals/src/utils/misc.py
@@ -0,0 +1,7 @@
+import torch
+
+DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
+
+
+def is_dummy(tensor: torch.Tensor):
+ return tensor.numel() == 0  # True for any tensor with zero elements, such as DUMMY
diff --git a/petals/tests/conftest.py b/petals/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..57287c3b0ffbec41528dfbffb3ea129a56af254f
--- /dev/null
+++ b/petals/tests/conftest.py
@@ -0,0 +1,51 @@
+import asyncio
+import gc
+from contextlib import suppress
+
+import psutil
+import pytest
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.mpfuture import MPFuture
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+@pytest.fixture
+def event_loop():
+ """
+ This overrides the ``event_loop`` fixture from pytest-asyncio
+ (e.g. to make it compatible with ``asyncio.subprocess``).
+
+ This fixture is identical to the original one but does not call ``loop.close()`` in the end.
+ Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
+ However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+ For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+ fails if the loop is closed, but works if the loop is only stopped).
+ """
+
+ yield asyncio.get_event_loop()  # no teardown code follows, so this fixture never closes the loop
+
+
+@pytest.fixture(autouse=True, scope="session")
+def cleanup_children():
+ yield  # session-scoped autouse fixture: everything below is teardown, run after the tests
+
+ with RSAPrivateKey._process_wide_key_lock:
+ RSAPrivateKey._process_wide_key = None  # drop the cached process-wide key while holding its lock
+
+ gc.collect() # Call .__del__() for removed objects
+
+ children = psutil.Process().children(recursive=True)  # all descendants of the current process
+ if children:
+ logger.info(f"Cleaning up {len(children)} leftover child processes")
+ for child in children:
+ with suppress(psutil.NoSuchProcess):  # a child may already be gone
+ child.terminate()
+ psutil.wait_procs(children, timeout=1)  # wait up to one second before escalating to kill()
+ for child in children:
+ with suppress(psutil.NoSuchProcess):
+ child.kill()
+
+ MPFuture.reset_backend()
diff --git a/petals/tests/scripts/remove_old_models.py b/petals/tests/scripts/remove_old_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3ba408f0c61dda6657924d91ef4e708c83ba06
--- /dev/null
+++ b/petals/tests/scripts/remove_old_models.py
@@ -0,0 +1,25 @@
+import argparse
+from datetime import datetime
+
+from huggingface_hub import delete_repo, list_models
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
+ parser.add_argument("--author", type=str, default="bloom-testing", help="HF hub author (organization) whose old test models are removed")
+ parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
+ parser.add_argument("--use_auth_token", type=str, default=None, help="HF hub auth token passed to delete_repo")
+ parser.add_argument("--dry_run", action="store_true")
+
+ args = parser.parse_args()
+
+ for model in list_models(author=args.author, full=True):
+ last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")  # naive datetime; the "Z" suffix marks it as UTC
+
+ if model.modelId.endswith("-main") or "/test-" not in model.modelId:
+ continue # remove only test models
+
+ if (datetime.utcnow() - last_modified).total_seconds() > args.seconds_since_last_updated:  # compare UTC with UTC, not local time
+ if args.dry_run:
+ print(f"{model.modelId} can be deleted")
+ else:
+ delete_repo(token=args.use_auth_token, name=model.modelId, organization=args.author)
diff --git a/petals/tests/test.id b/petals/tests/test.id
new file mode 100644
index 0000000000000000000000000000000000000000..2806712526e5bd6f365c8ddd40d502ac1842de43
Binary files /dev/null and b/petals/tests/test.id differ
diff --git a/petals/tests/test_block_exact_match.py b/petals/tests/test_block_exact_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad84ae74bfa0643083a0767e189156dc5bba174
--- /dev/null
+++ b/petals/tests/test_block_exact_match.py
@@ -0,0 +1,46 @@
+import random
+
+import hivemind
+import pytest
+import torch
+import transformers
+from hivemind import P2PHandlerError
+from test_utils import *
+
+import src
+from src import DistributedBloomConfig
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_sequential import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import get_remote_module
+
+
+@pytest.mark.forked
+def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+ dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+ config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+
+ for block_index in random.sample(range(config.n_layer), 3):  # three distinct, randomly chosen block indices
+ remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
+ assert isinstance(remote_block, RemoteTransformerBlock)
+
+ inputs = torch.randn(1, 8, config.hidden_size)
+ outputs_forward = remote_block(inputs)  # whole input in a single forward call
+
+ outputs_inference = []
+ with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+ for i in range(inputs.shape[1]):  # same input again, one slice along axis 1 per step
+ outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+
+ # test that max length is respected
+ with pytest.raises(P2PHandlerError) as exc_info:
+ sess.step(inputs[:, -1:, :])
+ assert "Maximum length exceeded" in repr(exc_info.value)
+
+ outputs_inference = torch.cat(outputs_inference, dim=1)
+
+ ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+ (outputs_local,) = ref_block(inputs)  # locally loaded block used as the reference for both remote results
+
+ assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+ assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
diff --git a/petals/tests/test_chained_calls.py b/petals/tests/test_chained_calls.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf6d440eaa0ad40cb99a86680e3e233c6b92e58
--- /dev/null
+++ b/petals/tests/test_chained_calls.py
@@ -0,0 +1,79 @@
+######
+# Warning: this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+
+import hivemind
+import pytest
+import torch
+from test_utils import *
+
+import src
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_sequential import RemoteSequential
+from src.dht_utils import get_remote_sequence
+
+
+@pytest.mark.forked
+def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
+ dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+ config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+ remote_blocks = get_remote_sequence(dht, 3, 6, config)  # presumably blocks 3..5 (end-exclusive), matching ref_blocks below - TODO confirm
+ assert isinstance(remote_blocks, RemoteSequential)
+
+ ref_blocks = [
+ load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+ load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+ load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
+ ]
+ inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
+ outputs_rpc = remote_blocks.forward(inputs)
+ outputs_rpc.sum().backward()
+ grads_rpc = inputs.grad  # keeps the remote gradient tensor alive after inputs.grad is cleared
+
+ inputs.grad = None  # so the reference backward pass does not accumulate onto the remote gradient
+ hidden_states = inputs
+ for ref_block in ref_blocks:
+ hidden_states = ref_block.forward(hidden_states)[0]  # first element of the block output is fed to the next block
+ outputs_ref = hidden_states
+ outputs_ref.sum().backward()
+ grads_ref = inputs.grad
+
+ assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
+ assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
+
+
+@pytest.mark.forked
+def test_chained_inference_exact_match(atol_inference=1e-4):
+ dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+ config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+ remote_blocks = get_remote_sequence(dht, 3, 5, config)  # presumably blocks 3..4 (end-exclusive), matching ref_blocks below - TODO confirm
+ assert isinstance(remote_blocks, RemoteSequential)
+
+ inputs = torch.randn(1, 8, config.hidden_size)
+
+ outputs_inference = []
+ with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
+ for i in range(inputs.shape[1]):  # one slice along axis 1 per step
+ outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+ outputs_inference = torch.cat(outputs_inference, dim=1)
+
+ ref_blocks = [
+ load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+ load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+ ]
+ outputs_ref = []
+ caches = [None, None]  # one layer_past entry per reference block; None on the first step
+ for i in range(inputs.shape[1]):
+ new_caches = []
+ hidden_states = inputs[:, i : i + 1, :]
+ for ref_block, cache in zip(ref_blocks, caches):
+ with torch.no_grad():
+ hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
+ new_caches.append(new_cache)
+
+ outputs_ref.append(hidden_states)
+ caches = new_caches  # caches returned at this step are passed as layer_past at the next one
+ outputs_ref = torch.cat(outputs_ref, dim=1)
+ assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)
diff --git a/petals/tests/test_full_model.py b/petals/tests/test_full_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ce824bc8967277f611dc25b30cc8e9da47769f
--- /dev/null
+++ b/petals/tests/test_full_model.py
@@ -0,0 +1,91 @@
+import pytest
+import torch
+import transformers
+from hivemind import get_logger, use_hivemind_log_handler
+from test_utils import *
+
+from src.bloom.model import BloomForCausalLM
+from src.client.remote_model import DistributedBloomForCausalLM
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@pytest.mark.forked
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+ model = DistributedBloomForCausalLM.from_pretrained(
+ MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+ )
+ config = model.config  # kept separately because `model` is deleted before the reference comparison
+ assert isinstance(model, DistributedBloomForCausalLM)
+ assert len(model.transformer.h) == model.config.n_layer
+
+ test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+ with torch.inference_mode():
+ parallel_outputs = model.forward(test_inputs).logits
+ assert torch.all(torch.isfinite(parallel_outputs))
+ logger.info("Forward outputs are finite")
+
+ embs = model.transformer.word_embeddings(test_inputs)  # rebuild the forward pass by hand: embeddings ...
+ embs = model.transformer.word_embeddings_layernorm(embs)
+ recurrent_outputs = []
+ with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
+ for t in range(embs.shape[1]):  # ... then the blocks, one slice along axis 1 per step ...
+ recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+ recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+ recurrent_outputs = model.transformer.ln_f(recurrent_outputs)  # ... then ln_f and lm_head
+ recurrent_outputs = model.lm_head(recurrent_outputs)
+ assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+ logger.info("Inference is consistent with forward")
+
+ del model, embs, recurrent_outputs
+
+ if REF_NAME:
+ ref_model = transformers.BloomForCausalLM.from_pretrained(
+ REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+ )
+ if config.vocab_size < ref_model.config.vocab_size:
+ ref_model.resize_token_embeddings(config.vocab_size)
+ logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
+
+ dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+ # note: this creates a dummy mask to make the test compatible with older transformer versions
+ # prior to https://github.com/huggingface/transformers/pull/17837
+ ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
+ assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+ logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
+ del ref_model, ref_outputs, dummy_mask
+ else:
+ logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+ assert False  # a missing REF_NAME fails the test instead of skipping the comparison silently
+
+
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4):
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+ model = DistributedBloomForCausalLM.from_pretrained(
+ MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+ )
+ inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+ remote_outputs = model.generate(
+ inputs,
+ max_new_tokens=max_new_tokens,
+ )
+ hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)  # reference: BloomForCausalLM.greedy_search run on the same model
+ assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
+
+ inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+ "input_ids"
+ ]  # second scenario: a batch of two prompts, tokenized with padding enabled
+ remote_outputs_batch = model.generate(
+ inputs_batch,
+ max_new_tokens=max_new_tokens,
+ )
+ hf_outputs_batch = BloomForCausalLM.greedy_search(
+ model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
+ )
+ assert torch.allclose(
+ remote_outputs_batch, hf_outputs_batch
+ ), "Greedy search are not identical to HF in multibatch mode"
diff --git a/petals/tests/test_priority_pool.py b/petals/tests/test_priority_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..21dd74effb18653347ce745ee37e444200f3a52c
--- /dev/null
+++ b/petals/tests/test_priority_pool.py
@@ -0,0 +1,71 @@
+import multiprocessing as mp
+import time
+
+import pytest
+import torch
+
+from src.server.runtime import Runtime
+from src.server.task_pool import PrioritizedTaskPool
+
+
+@pytest.mark.forked
+def test_priority_pools():
+ outputs_queue = mp.SimpleQueue()  # records (input, output) pairs in the order the tasks were actually processed
+ results_valid = mp.Event()  # set by the submitting process once every future returned the expected value
+
+ def dummy_pool_func(x):
+ time.sleep(0.1)
+ y = x**2
+ outputs_queue.put((x, y))
+ return (y,)
+
+ class DummyBackend:
+ def __init__(self, pools):
+ self.pools = pools
+
+ def get_pools(self):
+ return self.pools
+
+ pools = (
+ PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
+ PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
+ )
+
+ runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+ runtime.start()
+
+ def process_tasks():
+ futures = []
+ futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
+ futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
+ time.sleep(0.01)
+ futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
+ futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
+ futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
+ futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
+ futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
+ futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
+ futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
+ for i, f in enumerate(futures):  # task i was submitted with tensor([i]), so its result must be i squared
+ assert f.result()[0].item() == i**2
+ results_valid.set()
+
+ proc = mp.Process(target=process_tasks)  # tasks are submitted from a separate process
+ proc.start()
+ proc.join()
+ assert results_valid.is_set()
+
+ ordered_outputs = []
+ while not outputs_queue.empty():
+ ordered_outputs.append(outputs_queue.get()[0].item())
+
+ assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
+ # 0 - first batch is loaded immediately, before everything else
+ # 5 - highest priority task overall
+ # 1 - first of several remaining tasks sharing the lowest priority value (1); lower values are served first
+ # 2 - second earliest task with priority 1, fetched from pool B
+ # 6 - third earliest task with priority 1, fetched from pool A again
+ # 8 - last priority-1 task, pool B
+ # 3 - task with priority 2 from pool A
+ # 4 - task with priority 10 from pool A
+ # 7 - task with priority 11 from pool B
diff --git a/petals/tests/test_remote_sequential.py b/petals/tests/test_remote_sequential.py
new file mode 100644
index 0000000000000000000000000000000000000000..678ec01ee18f38a7c1dfaa01984f5e7ad0d39264
--- /dev/null
+++ b/petals/tests/test_remote_sequential.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+from test_utils import *
+
+from src import RemoteSequential
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_model import DistributedBloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@pytest.mark.forked
+def test_remote_sequential():
+ config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+ dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+ test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
+ grad_proj = torch.randn(1, 5, config.hidden_size)  # random weights that turn the outputs into a scalar loss
+
+ sequential = RemoteSequential(config, dht)
+
+ full_outputs = sequential(test_inputs)
+ (full_outputs * grad_proj).sum().backward()
+ assert test_inputs.grad is not None
+ full_grad = test_inputs.grad.clone()  # saved before the gradient buffer is zeroed for the second pass
+ test_inputs.grad.data.zero_()
+
+ first_half = sequential[: config.n_layer // 2]
+ second_half = sequential[config.n_layer // 2 :]
+ assert len(first_half) + len(second_half) == len(sequential)
+ assert abs(len(first_half) - len(second_half)) == config.n_layer % 2  # halves differ by one only when n_layer is odd
+ for m in sequential, first_half, second_half:
+ assert isinstance(repr(m), str)
+
+ hidden = first_half(test_inputs)
+ assert isinstance(hidden, torch.Tensor)
+ assert hidden.shape == test_inputs.shape
+ assert hidden.requires_grad
+ second_half_outputs = second_half(hidden)
+ assert torch.allclose(second_half_outputs, full_outputs)  # chaining both halves must reproduce the full pass
+
+ (second_half_outputs * grad_proj).sum().backward()
+ assert torch.allclose(test_inputs.grad, full_grad)
+
+
+@pytest.mark.forked
+def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
+ config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+ dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+ remote_sequential = RemoteSequential(config, dht)
+
+ inputs = torch.randn(batch_size, seq_len, config.hidden_size)
+ output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)  # random weights that turn the outputs into a scalar loss
+ input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+ intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)  # one prompt tensor per block
+
+ input_prompts = input_prompts.detach().requires_grad_(True)
+ intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
+
+ inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
+ assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
+
+ outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
+
+ (outputs * output_proj).sum().backward()
+ assert intermediate_prompts.grad is not None
+
+ input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)  # independent copies for the local reference pass
+ intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)
+
+ assert input_prompts_ref.grad is None
+ assert intermediate_prompts_ref.grad is None
+
+ outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
+ for block_index in range(config.n_layer):
+ block_prompt = intermediate_prompts_ref[block_index]
+ outputs_ref[:, : block_prompt.shape[1]] += block_prompt  # reference: add this block's prompt to the leading positions, then run the block
+
+ block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
+ (outputs_ref,) = block(outputs_ref)
+
+ assert torch.allclose(outputs_ref, outputs)
+
+ (outputs_ref * output_proj).sum().backward()
+ assert input_prompts_ref.grad is not None
+ assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+ assert intermediate_prompts_ref.grad is not None
+ assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
diff --git a/petals/tests/test_utils.py b/petals/tests/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee440d616d3068e5a11a8adf65dddd07786743e9
--- /dev/null
+++ b/petals/tests/test_utils.py
@@ -0,0 +1,13 @@
+import os
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+ raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()  # whitespace-separated string -> list of peers
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+ raise RuntimeError("Must specify MODEL_NAME environment variable with the name of the model to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")  # optional: None when unset, no error is raised here
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0a2a30fccf93c47297947ed2ed24d119d51d6df7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch>=1.12
+bitsandbytes>=0.35.4
+accelerate>=0.10.0
+huggingface-hub>=0.7.0
+transformers==4.21.3
+protobuf>=3.12.2,<4.0.0
+hivemind>=1.1.2
+humanfriendly
+gradio