Duyu committed on
Commit
1800539
·
verified ·
1 Parent(s): 15077ed

Upload 2 files

Browse files
Files changed (2) hide show
  1. hmm_model_large.pkl.bz2 +2 -2
  2. py2hz.py +34 -31
hmm_model_large.pkl.bz2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2d5214166cb693789749314a5bfc362fd317b24ecf9ae339f4fc8ea78c7c5cc6
3
- size 3808875
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b311984d38a755825bb43db500c6f0a0b3d22ecafe10483aa04b27ecf1d6e53d
3
+ size 3814973
py2hz.py CHANGED
@@ -1,7 +1,7 @@
1
  # _*_ coding:utf-8 _*_
2
  """
3
- @Version : 1.1.0
4
- @Time : 2024年12月28
5
  @Author : DuYu (@duyu09, [email protected])
6
  @File : py2hz.py
7
  @Describe : 基于隐马尔可夫模型(HMM)的拼音转汉字程序。
@@ -92,36 +92,28 @@ def train_hmm(sentences, pinyins, hanzi2id, pinyin2id):
92
  zero_emit_rows = (emit_sums == 0).flatten()
93
  emit_prob[zero_emit_rows, :] = 1.0 / n_observations # 均匀填充
94
  emit_prob /= emit_prob.sum(axis=1, keepdims=True)
95
-
96
  model.startprob_ = start_prob
97
  model.transmat_ = trans_prob
98
  model.emissionprob_ = emit_prob
99
-
100
  return model
101
 
102
  # 4. 保存和加载模型
103
- def save_model(model, filepath):
104
- with bz2.BZ2File(filepath, 'wb') as f:
105
- pickle.dump(model, f)
106
-
107
- def load_model(filepath):
108
- with bz2.BZ2File(filepath, 'rb') as f:
109
- return pickle.load(f)
110
-
111
- def predict(model, pinyin_seq, pinyin2id, id2hanzi):
112
- obs_seq = np.zeros((len(pinyin_seq), len(pinyin2id))) # 转换观测序列为 one-hot 格式
113
-
114
- for t, p in enumerate(pinyin_seq):
115
- if p in pinyin2id:
116
- obs_seq[t, pinyin2id[p]] = 1
117
- else:
118
- obs_seq[t, 0] = 1 # 未知拼音默认处理
119
-
120
- # 解码预测
121
- model.n_trials = 3 # 运行3次
122
- log_prob, state_seq = model.decode(obs_seq, algorithm='viterbi')
123
- result = ''.join([id2hanzi[s] for s in state_seq])
124
- return result
125
 
126
 
127
  def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
@@ -135,12 +127,23 @@ def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
135
  save_model(model, model_path) # 保存模型
136
 
137
 
138
- def pred(model_path='hmm_model.pkl.bz2', pinyin_str='ce4 shi4'):
139
- model = load_model(model_path) # 加载模型
140
  pinyin_list = pinyin_str.split()
141
- result = predict(model, pinyin_list, model.pinyin2id, model.id2hanzi)
 
 
 
 
 
 
 
 
 
 
 
142
  print('预测结果:', result)
143
 
144
  if __name__ == '__main__':
145
- # train(dataset_path='train_o.csv', model_path='hmm_model.pkl.bz2')
146
- pred(model_path='hmm_model.pkl.bz2', pinyin_str='hong2 yan2 bo2 ming4')
 
1
  # _*_ coding:utf-8 _*_
2
  """
3
+ @Version : 1.2.0
4
+ @Time : 2024年12月29
5
  @Author : DuYu (@duyu09, [email protected])
6
  @File : py2hz.py
7
  @Describe : 基于隐马尔可夫模型(HMM)的拼音转汉字程序。
 
92
  zero_emit_rows = (emit_sums == 0).flatten()
93
  emit_prob[zero_emit_rows, :] = 1.0 / n_observations # 均匀填充
94
  emit_prob /= emit_prob.sum(axis=1, keepdims=True)
95
+
96
  model.startprob_ = start_prob
97
  model.transmat_ = trans_prob
98
  model.emissionprob_ = emit_prob
 
99
  return model
100
 
101
  # 4. 保存和加载模型
102
+ def save_model(model, filepath, mode='compress'): # mode='normal'意味着不使用压缩
103
+ if mode == 'normal':
104
+ with open(filepath, 'wb') as f:
105
+ pickle.dump(model, f)
106
+ else:
107
+ with bz2.BZ2File(filepath, 'wb') as f:
108
+ pickle.dump(model, f)
109
+
110
+ def load_model(filepath, mode='compress'): # mode='normal'意味着不使用压缩
111
+ if mode == 'normal':
112
+ with open(filepath, 'rb') as f:
113
+ return pickle.load(f)
114
+ else:
115
+ with bz2.BZ2File(filepath, 'rb') as f:
116
+ return pickle.load(f)
 
 
 
 
 
 
 
117
 
118
 
119
  def train(dataset_path='train.csv', model_path='hmm_model.pkl.bz2'):
 
127
  save_model(model, model_path) # 保存模型
128
 
129
 
130
+ def pred(model_path='hmm_model.pkl.bz2', pinyin_str='ce4 shi4', n_trials=3):
131
+ model = load_model(model_path)
132
  pinyin_list = pinyin_str.split()
133
+ pinyin2id, id2hanzi = model.pinyin2id, model.id2hanzi
134
+ obs_seq = np.zeros((len(pinyin_list), len(pinyin2id))) # 转换观测序列为 one-hot 格式
135
+ for t, p in enumerate(pinyin_list):
136
+ if p in pinyin2id:
137
+ obs_seq[t, pinyin2id[p]] = 1
138
+ else:
139
+ obs_seq[t, 0] = 1 # 未知拼音默认处理
140
+
141
+ # 解码预测
142
+ model.n_trials = n_trials
143
+ log_prob, state_seq = model.decode(obs_seq, algorithm=model.algorithm)
144
+ result = ''.join([id2hanzi[s] for s in state_seq])
145
  print('预测结果:', result)
146
 
147
  if __name__ == '__main__':
148
+ # train(dataset_path='train.csv', model_path='hmm_model_large.pkl.bz2')
149
+ pred(model_path='hmm_model_large.pkl.bz2', pinyin_str='hong2 yan2 bo2 ming4') # 预测结果:红颜薄命