minko186 committed on
Commit 3355824 · verified · 1 Parent(s): b5d332e

Update gptzero_free.py

Files changed (1)
  1. gptzero_free.py +116 -0
gptzero_free.py CHANGED
@@ -0,0 +1,116 @@
+ """
+ This code is a slight modification of the perplexity example from Hugging Face:
+ https://huggingface.co/docs/transformers/perplexity
+
+ Both this code and the original code are published under the MIT license.
+
+ by Burhan Ul Tayyab and Nicholas Chua
+ """
+
+ import torch
+ import re
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
+ from collections import OrderedDict
+
+
+ class GPT2PPL:
+     def __init__(self, device="cuda", model_id="gpt2"):
+         self.device = device
+         self.model_id = model_id
+         self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
+         self.tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
+
+         # maximum context window of the model; the stride controls the sliding window in getPPL
+         self.max_length = self.model.config.n_positions
+         self.stride = 512
+
+     def getResults(self, threshold):
+         if threshold < 60:
+             label = 0
+             return "The Text is generated by AI.", label
+         elif threshold < 80:
+             label = 0
+             return "The Text most probably contains parts which are generated by AI. (Requires more text for a better judgement.)", label
+         else:
+             label = 1
+             return "The Text is written by a Human.", label
+
+     def __call__(self, sentence):
+         """
+         Takes in a passage, computes the perplexity of the whole passage,
+         then splits it into sentences (on full stops, question marks,
+         exclamation marks and newlines), computes the perplexity of each
+         sentence, and reports the average per-sentence perplexity.
+
+         Burstiness is the maximum per-sentence perplexity.
+         """
+         results = OrderedDict()
+
+         total_valid_char = re.findall("[a-zA-Z0-9]+", sentence)
+         total_valid_char = sum([len(x) for x in total_valid_char])  # number of alphanumeric characters in the passage
+
+         if total_valid_char < 100:
+             return {"status": "Please input more text (min 100 characters)"}, "Please input more text (min 100 characters)"
+
+         # split on sentence-ending punctuation followed by a space/bracket, or on newlines
+         lines = re.split(r'(?<=[.?!][ \[\(])|(?<=\n)\s*', sentence)
+         lines = list(filter(lambda x: (x is not None) and (len(x) > 0), lines))
+
+         ppl = self.getPPL(sentence)
+         print(f"Perplexity {ppl}")
+         results["Perplexity"] = ppl
+
+         offset = ""
+         Perplexity_per_line = []
+         for i, line in enumerate(lines):
+             if re.search("[a-zA-Z0-9]+", line) is None:
+                 continue
+             if len(offset) > 0:
+                 line = offset + line
+                 offset = ""
+             # remove a leading newline or space if one exists
+             if line[0] == "\n" or line[0] == " ":
+                 line = line[1:]
+             # drop a trailing newline/space; carry a trailing "[" or "(" over to the next sentence
+             if line[-1] == "\n" or line[-1] == " ":
+                 line = line[:-1]
+             elif line[-1] == "[" or line[-1] == "(":
+                 offset = line[-1]
+                 line = line[:-1]
+             ppl = self.getPPL(line)
+             Perplexity_per_line.append(ppl)
+
+         print(f"Perplexity per line {sum(Perplexity_per_line)/len(Perplexity_per_line)}")
+         results["Perplexity per line"] = sum(Perplexity_per_line)/len(Perplexity_per_line)
+
+         print(f"Burstiness {max(Perplexity_per_line)}")
+         results["Burstiness"] = max(Perplexity_per_line)
+
+         out, label = self.getResults(results["Perplexity per line"])
+         results["label"] = label
+
+         return results, out
+
+     def getPPL(self, sentence):
+         # sliding-window perplexity, following the Hugging Face example
+         encodings = self.tokenizer(sentence, return_tensors="pt")
+         seq_len = encodings.input_ids.size(1)
+
+         nlls = []
+         likelihoods = []
+         prev_end_loc = 0
+         for begin_loc in range(0, seq_len, self.stride):
+             end_loc = min(begin_loc + self.max_length, seq_len)
+             trg_len = end_loc - prev_end_loc  # may differ from the stride on the last window
+             input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
+             target_ids = input_ids.clone()
+             target_ids[:, :-trg_len] = -100  # mask tokens already scored in a previous window
+
+             with torch.no_grad():
+                 outputs = self.model(input_ids, labels=target_ids)
+                 # loss is the mean NLL over the target tokens; rescale to a sum
+                 neg_log_likelihood = outputs.loss * trg_len
+                 likelihoods.append(neg_log_likelihood)
+
+             nlls.append(neg_log_likelihood)
+
+             prev_end_loc = end_loc
+             if end_loc == seq_len:
+                 break
+         ppl = int(torch.exp(torch.stack(nlls).sum() / end_loc))
+         return ppl
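
For reference, a minimal usage sketch of the class added in this commit. The sample text, the `device="cpu"` override, and the import path are illustrative assumptions, not part of the committed file; they assume the script is run next to gptzero_free.py and that a GPU may not be available.

# Illustrative usage sketch only; not part of the commit.
from gptzero_free import GPT2PPL

detector = GPT2PPL(device="cpu", model_id="gpt2")  # default device in the class is "cuda"

sample = (
    "Perplexity measures how well a language model predicts a piece of text. "
    "Text that the model finds highly predictable receives a low perplexity score, "
    "while more surprising, varied writing receives a higher one."
)  # must contain at least 100 alphanumeric characters, or the class returns an error message

results, verdict = detector(sample)
print(results)   # OrderedDict with Perplexity, Perplexity per line, Burstiness, label
print(verdict)   # human-readable message produced by getResults()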