引入Redis|tensorflow实现 聊天AI–PigPig养成记(3)

30次阅读

共计 3568 个字符,预计需要花费 9 分钟才能阅读完成。

引入 Redis
项目 github 链接
在集成 Netty 之后,为了提高效率,我打算将消息存储在 Redis 缓存系统中,本节将介绍 Redis 在项目中的引入,以及前端界面的开发。
引入 Redis 后,完整代码链接。
想要直接得到训练了 13000 步的聊天机器人可以直接下载链接中这三个文件,以及词汇表文件然后直接运行连接中的 py 脚本进行测试即可。
最终实现效果如下:

在 Netty 中引入 Redis
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.time.LocalDateTime;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.GlobalEventExecutor;
import redis.clients.jedis.Jedis;

public class ChatHandler
extends SimpleChannelInboundHandler<TextWebSocketFrame>{
private static ChannelGroup clients=
new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

@Override
protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
System.out.println(“channelRead0…”);

// 连接 redis
Jedis jedis=new Jedis(“localhost”);
System.out.println(“ 连接成功 …”);
System.out.println(“ 服务正在运行:”+jedis.ping());

// 得到用户输入的消息,需要写入文件 / 缓存中,让 AI 进行读取
String content=msg.text();
if(content==null||content==””) {
System.out.println(“content 为 null”);
return ;
}
System.out.println(“ 接收到的消息:”+content);

// 写入缓存中
jedis.set(“user_say”, content+”:user”);

Thread.sleep(1000);
// 读取 AI 返回的内容
String AIsay=null;
while(AIsay==”no”||AIsay==null) {
// 从缓存中读取 AI 回复的内容
AIsay=jedis.get(“ai_say”);
String [] arr=AIsay.split(“:”);
AIsay=arr[0];
}

// 读取后马上向缓存中写入
jedis.set(“ai_say”, “no”);
// 没有说,或者还没说
if(AIsay==null||AIsay==””) {
System.out.println(“AIsay==null||AIsay==\”\””);
return;
}
System.out.println(“AI 说:”+AIsay);

clients.writeAndFlush(
new TextWebSocketFrame(
“AI_PigPig 在 ”+LocalDateTime.now()
+” 说:”+AIsay));
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
System.out.println(“add…”);
clients.add(ctx.channel());
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
System.out.println(“ 客户端断开,channel 对应的长 id 为:”
+ctx.channel().id().asLongText());
System.out.println(“ 客户端断开,channel 对应的短 id 为:”
+ctx.channel().id().asShortText());
}

}

在 Python 中引入 Redis
with tf.Session() as sess:# 打开作为一次会话
# 恢复前一次训练
ckpt = tf.train.get_checkpoint_state(‘.’)# 从检查点文件中返回一个状态 (ckpt)
#如果 ckpt 存在,输出模型路径
if ckpt != None:
print(ckpt.model_checkpoint_path)
model.saver.restore(sess, ckpt.model_checkpoint_path)# 储存模型参数
else:
print(“ 没找到模型 ”)
r.set(‘user_say’,’no’)
#测试该模型的能力
while True:
line=’no’
#从缓存中进行读取
while line==’no’:
line=r.get(‘user_say’).decode()
#print(line)
list1=line.split(‘:’)
if len(list1)==1:
input_string=’no’
else:
input_string=list1[0]
r.set(‘user_say’,’no’)

# 退出
if input_string == ‘quit’:
exit()
if input_string != ‘no’:
input_string_vec = []# 输入字符串向量化
for words in input_string.strip():
input_string_vec.append(vocab_en.get(words, UNK_ID))#get() 函数:如果 words 在词表中,返回索引号;否则,返回 UNK_ID
bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])# 保留最小的大于输入的 bucket 的 id
encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)
#get_batch(A,B): 两个参数,A 为大小为 len(buckets) 的元组,返回了指定 bucket_id 的 encoder_inputs,decoder_inputs,target_weights
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
#得到其输出
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]# 求得最大的预测范围列表
if EOS_ID in outputs:# 如果 EOS_ID 在输出内部,则输出列表为 [,,,,:End]
outputs = outputs[:outputs.index(EOS_ID)]

response = “”.join([tf.compat.as_str(vocab_de[output]) for output in outputs])# 转为解码词汇分别添加到回复中
print(‘AI-PigPig > ‘ + response)# 输出回复
#向缓存中进行写入
r.set(‘ai_say’,response+’:AI’)

下一节将讲述通信规则的制定,以规范应用程序。

正文完
 0