Upload 2 files
Browse files
- hmm_model_large.pkl.bz2 +2 -2
- py2hz.py +34 -31
hmm_model_large.pkl.bz2
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b311984d38a755825bb43db500c6f0a0b3d22ecafe10483aa04b27ecf1d6e53d
|
3 |
+
size 3814973
|
py2hz.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# _*_ coding:utf-8 _*_
|
2 |
"""
|
3 |
-
@Version : 1.
|
4 |
-
@Time : 2024年12月
|
5 |
@Author : DuYu (@duyu09, [email protected])
|
6 |
@File : py2hz.py
|
7 |
@Describe : 基于隐马尔可夫模型(HMM)的拼音转汉字程序。
|
@@ -92,36 +92,28 @@ def train_hmm(sentences, pinyins, hanzi2id, pinyin2id):
|
|
92 |
zero_emit_rows = (emit_sums == 0).flatten()
|
93 |
emit_prob[zero_emit_rows, :] = 1.0 / n_observations # 均匀填充
|
94 |
emit_prob /= emit_prob.sum(axis=1, keepdims=True)
|
95 |
-
|
96 |
model.startprob_ = start_prob
|
97 |
model.transmat_ = trans_prob
|
98 |
model.emissionprob_ = emit_prob
|
99 |
-
|
100 |
return model
|
101 |
|
102 |
# 4. 保存和加载模型
|
103 |
-
def save_model(model, filepath):
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
def
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
obs_seq[t, 0] = 1 # 未知拼音默认处理
|
119 |
-
|
120 |
-
# 解码预测
|
121 |
-
model.n_trials = 3 # 运行3次
|
122 |
-
log_prob, state_seq = model.decode(obs_seq, algorithm='viterbi')
|
123 |
-
result = ''.join([id2hanzi[s] for s in state_seq])
|
124 |
-
return result
|
125 |
|
126 |
|
127 |
def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
|
@@ -135,12 +127,23 @@ def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
|
|
135 |
save_model(model, model_path) # 保存模型
|
136 |
|
137 |
|
138 |
-
def pred(model_path='hmm_model.pkl.bz2', pinyin_str='ce4 shi4'):
|
139 |
-
model = load_model(model_path)
|
140 |
pinyin_list = pinyin_str.split()
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
print('预测结果:', result)
|
143 |
|
144 |
if __name__ == '__main__':
|
145 |
-
# train(dataset_path='
|
146 |
-
pred(model_path='
|
|
|
1 |
# _*_ coding:utf-8 _*_
|
2 |
"""
|
3 |
+
@Version : 1.2.0
|
4 |
+
@Time : 2024年12月29日
|
5 |
@Author : DuYu (@duyu09, [email protected])
|
6 |
@File : py2hz.py
|
7 |
@Describe : 基于隐马尔可夫模型(HMM)的拼音转汉字程序。
|
|
|
92 |
zero_emit_rows = (emit_sums == 0).flatten()
|
93 |
emit_prob[zero_emit_rows, :] = 1.0 / n_observations # 均匀填充
|
94 |
emit_prob /= emit_prob.sum(axis=1, keepdims=True)
|
95 |
+
|
96 |
model.startprob_ = start_prob
|
97 |
model.transmat_ = trans_prob
|
98 |
model.emissionprob_ = emit_prob
|
|
|
99 |
return model
|
100 |
|
101 |
# 4. 保存和加载模型
|
102 |
+
def save_model(model, filepath, mode='compress'):
    """Serialize ``model`` to ``filepath`` with pickle.

    Args:
        model: Any picklable object (in this file, the trained HMM).
        filepath: Destination path; opened in binary write mode and
            overwritten if it already exists.
        mode: ``'normal'`` writes a plain, uncompressed pickle file. Any
            other value (default ``'compress'``) writes a bz2-compressed
            pickle.
    """
    # Choose the opener once so the dump logic is not duplicated per mode.
    opener = open if mode == 'normal' else bz2.BZ2File
    with opener(filepath, 'wb') as f:
        pickle.dump(model, f)
|
109 |
+
|
110 |
+
def load_model(filepath, mode='compress'):
    """Load and return a pickled object from ``filepath``.

    Args:
        filepath: Path of the pickle file to read.
        mode: ``'normal'`` reads a plain, uncompressed pickle file. Any
            other value (default ``'compress'``) reads the file as a
            bz2-compressed pickle.

    Returns:
        The object stored in the file.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load model
    # files that come from a trusted source.
    opener = open if mode == 'normal' else bz2.BZ2File
    with opener(filepath, 'rb') as f:
        return pickle.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
|
119 |
def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
|
|
|
127 |
save_model(model, model_path) # 保存模型
|
128 |
|
129 |
|
130 |
+
def pred(model_path='hmm_model.pkl.bz2', pinyin_str='ce4 shi4', n_trials=3):
    """Decode a whitespace-separated pinyin string into Chinese characters.

    Loads the model from ``model_path``, one-hot encodes each pinyin token,
    decodes the most likely hidden-state sequence and prints the result.

    Args:
        model_path: Path of the saved model, read via ``load_model``.
        pinyin_str: Pinyin syllables separated by whitespace.
        n_trials: Value assigned to ``model.n_trials`` before decoding.

    Returns:
        The decoded string (also printed). An empty or whitespace-only
        ``pinyin_str`` yields an empty string without calling the decoder.
    """
    model = load_model(model_path)
    pinyin_list = pinyin_str.split()
    if not pinyin_list:
        # There is nothing to decode; skip the decoder, which would be
        # handed a zero-row observation matrix.
        result = ''
        print('预测结果:', result)
        return result

    pinyin2id, id2hanzi = model.pinyin2id, model.id2hanzi
    # One-hot observation matrix: one row per token, one column per pinyin.
    obs_seq = np.zeros((len(pinyin_list), len(pinyin2id)))
    for t, p in enumerate(pinyin_list):
        # Unknown pinyin falls back to column 0, i.e. it is decoded as
        # whichever pinyin has id 0.
        obs_seq[t, pinyin2id.get(p, 0)] = 1

    # Decode.
    # NOTE(review): the effect of n_trials depends on the model class, which
    # is not visible here; confirm it matches the one-hot rows.
    model.n_trials = n_trials
    _, state_seq = model.decode(obs_seq, algorithm=model.algorithm)
    result = ''.join(id2hanzi[s] for s in state_seq)
    print('预测结果:', result)
    return result
|
146 |
|
147 |
if __name__ == '__main__':
    # Uncomment to retrain and overwrite the saved model:
    # train(dataset_path='train.csv', model_path='hmm_model_large.pkl.bz2')
    pred(model_path='hmm_model_large.pkl.bz2', pinyin_str='hong2 yan2 bo2 ming4')  # expected output: 红颜薄命
|