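# NOTE: the following lines are web-page residue captured together with the
# file; they are kept as comments so that the module remains valid Python.
# TSA / outline.py
# QINGCHE's picture
# fix bugs
# bff547d
# raw
# history blame
# 2.91 kB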
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
def find_parent(matrix, node):
    """Return the row index of the largest entry in column ``node``.

    Column ``node`` of ``matrix`` is treated as the weights of the edges
    pointing at ``node``; the row holding the largest weight is its parent.

    Args:
        matrix: 2-D NumPy array (it is indexed as ``matrix[:, node]``).
        node: Column index of the node whose parent is wanted.

    Returns:
        The index produced by ``np.argmax`` for that column, or ``None`` when
        the largest weight is not strictly positive.
    """
    incoming = matrix[:, node]
    strongest = np.argmax(incoming)
    # A non-positive maximum means no usable incoming edge exists.
    return strongest if incoming[strongest] > 0 else None
def find_tree(matrix, node, depth=1, children=None, max_depth=1, visited=None):
    """Collect the edges that link ``node`` to its parent and to its children.

    The parent is looked up with ``find_parent``. When a parent exists and has
    not been visited, the edge ``[parent, node]`` is emitted, followed by one
    ``[node, child]`` edge per entry of ``children``. While ``depth`` is below
    ``max_depth`` the walk continues upwards from the parent.

    Args:
        matrix: Weight matrix handed through to ``find_parent``.
        node: Index of the node to start from.
        depth: Depth of this call; the initial call uses the default of 1.
        children: Indices of the children of ``node``. Defaults to no children.
        max_depth: Largest depth at which the walk may still recurse upwards.
        visited: Set of nodes already expanded. A caller-supplied set is
            updated in place; when omitted, a fresh set is created per call.

    Returns:
        A list of ``[parent, child]`` pairs; empty when ``node`` has no
        unvisited parent.
    """
    # Fresh containers per call: mutable default arguments are created once
    # and would otherwise leak state between unrelated calls.
    if children is None:
        children = []
    if visited is None:
        visited = set()

    result = []
    parent = find_parent(matrix, node)
    # NOTE(review): the original file lost its indentation; the child edges and
    # the recursion are assumed to belong inside this branch -- confirm.
    if parent is not None and parent not in visited:
        result.append([parent, node])
        result.extend([node, child] for child in children)
        if depth < max_depth:
            # Mark the node before climbing so that a cycle cannot lead back.
            visited.add(node)
            # Forward max_depth; otherwise the recursive call falls back to
            # the default of 1 and the walk always stops after one extra level.
            result.extend(
                find_tree(
                    matrix,
                    parent,
                    depth + 1,
                    max_depth=max_depth,
                    visited=visited,
                )
            )
    # Return the list of collected edges.
    return result
#
def find_prob(tree, matrix):
    """Multiply together the matrix weights of every edge in ``tree``.

    Args:
        tree: Iterable of ``(parent, child)`` index pairs.
        matrix: Weights, read as ``matrix[parent][child]``.

    Returns:
        The product of the edge weights; ``1`` when ``tree`` has no edges.
    """
    likelihood = 1
    for edge in tree:
        src, dst = edge
        likelihood = likelihood * matrix[src][dst]
    return likelihood
def find_forests(matrix, k):
    """Pick the highest-scoring local tree and return it as a parent->children map.

    For every node, the positive entries of its row are taken as its children,
    a tree is built with ``find_tree`` and scored with ``find_prob``. Scores of
    identical trees are summed, and the tree with the largest total wins (the
    first one encountered wins on ties).

    Args:
        matrix: Square weight matrix; row ``i`` holds the weights from node
            ``i`` to every other node.
        k: Not used by the current implementation; kept so existing callers
            keep working.

    Returns:
        A tuple ``(result, prob)`` where ``result`` maps each parent index to
        the list of its child indices in the winning tree and ``prob`` is that
        tree's accumulated score. Returns ``({}, 0.0)`` when no node produces
        a tree (for example an empty or all-zero matrix) instead of raising
        ``IndexError``.
    """
    forests = {}
    for node, row in enumerate(matrix):
        child_list = [j for j, weight in enumerate(row) if weight > 0]
        # Edges are converted to tuples so the whole tree can be a dict key.
        tree = tuple(
            tuple(edge) for edge in find_tree(matrix, node, children=child_list)
        )
        if tree:
            forests[tree] = forests.get(tree, 0) + find_prob(tree, matrix)

    if not forests:
        # Nothing to rank: no node had an unvisited parent.
        return {}, 0.0

    forest, prob = max(forests.items(), key=lambda item: item[1])

    # Walk every edge of the winning tree and group children by parent.
    result = {}
    for parent, child in forest:
        result.setdefault(parent, []).append(child)
    return result, prob
def passage_outline(matrix,sentences):
    """Build a textual outline from the best tree found in ``matrix``.

    The parent->children map returned by ``find_forests`` is printed together
    with its score, then rendered with one topic header per parent (in sorted
    parent order) followed by the sentences of its children.

    Args:
        matrix: Weight matrix passed to ``find_forests``.
        sentences: Sequence of sentences, indexed by node number.

    Returns:
        A tuple ``(outline, outline_list)``: ``outline`` is the rendered text
        with a header line per topic and a ``- `` bullet per sentence;
        ``outline_list`` holds the same headers and sentences as a flat list.
    """
    result, prob = find_forests(matrix, 1)
    print(result, prob)

    # Replace child indices by the sentences they refer to.
    structure = {
        topic: [sentences[idx] for idx in members]
        for topic, members in result.items()
    }

    header = "主题:"
    outline_list = []
    text_parts = []
    for topic in sorted(structure):
        outline_list.append(header)
        text_parts.append(header + "\n")
        for sentence in structure[topic]:
            outline_list.append(sentence)
            text_parts.append(f"- {sentence}\n")
    return "".join(text_parts), outline_list
if __name__ == "__main__":
    # Demo run: a 5x5 weight matrix with a zero diagonal and five placeholder
    # sentences; only the rendered outline text is printed at the end.
    matrix = np.array(
        [
            [0.0, 0.02124888, 0.10647043, 0.09494194, 0.0689209],
            [0.01600688, 0.0, 0.05879448, 0.0331325, 0.0155093],
            [0.01491911, 0.01652437, 0.0, 0.04714563, 0.04577385],
            [0.01699071, 0.0313585, 0.040299, 0.0, 0.014933],
            [0.02308992, 0.02791895, 0.06547201, 0.08517842, 0.0],
        ]
    )
    sentences = [
        "主题句子1",
        "主题句子2",
        "主题句子3",
        "主题句子4",
        "主题句子5",
    ]
    outline_and_list = passage_outline(matrix, sentences)
    print(outline_and_list[0])