'''
Adapted from OpenAssistant's original xor_codec.py:
https://huggingface.co/OpenAssistant/oasst-sft-6-llama-30b-xor/raw/main/xor_codec.py

XOR every file in a payload directory with the same-named file in a base
directory and write the result to a destination directory.  With --compress,
encoding writes gzip-compressed output and decoding reads gzip-compressed
payloads; without it, plain XOR is used (which is its own inverse).

Usage: xor.py <DESTINATION> <PAYLOAD_SOURCE> <BASE_SOURCE> [--encode] [--compress]
'''

import os
import sys
import gzip
import numpy
from pathlib import Path


def _xor_file(dst, src_payload, src_base, block_size, open_payload, open_dst):
    """Stream ``payload XOR base`` into *dst*, *block_size* bytes at a time.

    The output is exactly as long as the payload.  If the base is shorter it is
    treated as zero-padded (those payload bytes are copied through unchanged);
    if it is longer, its excess bytes are ignored.

    :param open_payload: callable used to open *src_payload* (``open`` or ``gzip.open``).
    :param open_dst: callable used to open *dst* (``open`` or ``gzip.open``).
    :raises ValueError: if *block_size* is not positive (the read loop would never end).
    """
    if block_size <= 0:
        raise ValueError("block_size must be a positive integer, got %r" % (block_size,))
    # Open both sources before the destination so that a missing input does not
    # create or truncate the output file; `with` closes everything on error too.
    with open_payload(src_payload, 'rb') as fp_payload, \
            open(src_base, 'rb') as fp_base, \
            open_dst(dst, 'wb') as fp_out:
        while True:
            chunk = fp_payload.read(block_size)
            if not chunk:
                break
            payload = numpy.frombuffer(chunk, dtype=numpy.uint8)
            # Read exactly as many base bytes as payload bytes were obtained, so the
            # two streams stay aligned even if a read returns a short block.
            base_chunk = fp_base.read(len(chunk))
            base = numpy.zeros(len(payload), dtype=numpy.uint8)
            if base_chunk:
                base[:len(base_chunk)] = numpy.frombuffer(base_chunk, dtype=numpy.uint8)
            fp_out.write(numpy.bitwise_xor(payload, base).tobytes())


def xor_uncompressed(dst, src_payload, src_base, block_size=4096):
    """Write ``src_payload XOR src_base`` to *dst* without compression.

    XOR is its own inverse, so the same call both encodes and decodes.
    :raises ValueError: if *block_size* is not positive.
    """
    _xor_file(dst, src_payload, src_base, block_size, open, open)


def xor_encode(dst, src_payload, src_base, block_size=4096):
    """Write ``src_payload XOR src_base`` to *dst* as a gzip-compressed file.

    :raises ValueError: if *block_size* is not positive.
    """
    _xor_file(dst, src_payload, src_base, block_size, open, gzip.open)


def xor_decode(dst, src_payload, src_base, block_size=4096):
    """Read gzip-compressed *src_payload*, XOR it with *src_base*, write plain *dst*.

    :raises ValueError: if *block_size* is not positive.
    """
    _xor_file(dst, src_payload, src_base, block_size, gzip.open, open)


def xor_dir(dst, src_payload, src_base, decode=True, compress=True):
    """XOR every entry of *src_payload* with the same-named file in *src_base*.

    Results go to the same name under *dst*, which is created if needed.  With
    ``compress=False`` the *decode* flag has no effect, because plain XOR is
    symmetric.  Processing is best-effort: a failing entry is reported on
    stderr and the remaining entries are still processed.

    :returns: sorted list of entry names that failed (empty on full success).
    """
    if compress:
        xor = xor_decode if decode else xor_encode
    else:
        xor = xor_uncompressed
    Path(dst).mkdir(parents=True, exist_ok=True)
    failed = []
    # Sort so the processing order (and progress output) is deterministic.
    for name in sorted(os.listdir(src_payload)):
        print("[*] Processing '%s'" % name)
        try:
            xor(os.path.join(dst, name),
                os.path.join(src_payload, name),
                os.path.join(src_base, name))
        except Exception as e:
            # Report the cause instead of silently swallowing it, then continue.
            print("Exception when processing '%s': %s" % (name, e), file=sys.stderr)
            failed.append(name)
    return failed


def _main(argv):
    """Parse *argv* (``sys.argv``-style) and run :func:`xor_dir`; return an exit code."""
    if len(argv) < 4:
        print("Usage: xor.py <DESTINATION> <PAYLOAD_SOURCE> <BASE_SOURCE> [--encode] [--compress]",
              file=sys.stderr)
        return 1
    dst, src_payload, src_base = argv[1:4]
    decode = True
    compress = False
    for arg in argv[4:]:
        if arg == "--encode":
            decode = False
        elif arg == "--compress":
            compress = True
        else:
            print("Ignoring unknown option '%s'" % arg, file=sys.stderr)
    failed = xor_dir(dst, src_payload, src_base, decode=decode, compress=compress)
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(_main(sys.argv))