关于人工智能:聚会必备-使用PyTorch和DJL开发一个我画你猜小游戏

63次阅读

共计 5695 个字符,预计需要花费 15 分钟才能阅读完成。

春节假期邻近,亲朋小聚饭后不可避免须要来点乏味的事件打发工夫,「我画你猜」就是一种很好的消遣形式。然而往年既然提倡「就地过年」,那么无妨就把这样的游戏搬到网上吧,照样能够玩到嗨~

一些童鞋可能还有印象,2018 年时,Google 推出了《猜画小歌》利用:玩家能够间接与 AI 进行你画我猜的游戏。通过画出一个房子或者一个猫,AI 会推断出各种物品被画出的概率。它的实现得益于深度学习模型在其中的利用,通过深度神经网络的演绎,已经令人头疼的绘画辨认也变得大海捞针。现如今,只有应用一个简略的图片分类模型,咱们便能够轻松的实现绘画辨认。试试看这个在线涂鸦小游戏吧。

在过后,大部分机器学习计算工作仍旧须要依靠网络在云端进行。随着算力的一直增进,机器学习工作曾经能够间接在边缘设施部署,包含各类运行安卓零碎的智能手机。然而,因为安卓自身次要是用 Java,部署基于 Python 的各类深度学习模型变成了一个难题。为了解决这个问题,AWS 开发并开源了 DeepJavaLibrary (DJL),一个为 Java 量身定制的深度学习框架。

在下文中,咱们将尝试通过 PyTorch 预训练模型在在安卓平台构建一个涂鸦绘画的利用。因为总代码量会比拟多,咱们这次会挑重点把最要害的代码实现。大家能够后续参考咱们残缺的我的项目进行构建。

环境配置

为了兼容 DJL 需要的 Java 性能,这个我的项目须要 Android API 26 及以上的版本。能够参考咱们案例配置来节约一些工夫,上面是这个我的项目须要的依赖项:

dependencies {
 implementation 'androidx.appcompat:appcompat:1.2.0'
 implementation 'ai.djl:api:0.7.0'
 implementation 'ai.djl.android:core:0.7.0'
 runtimeOnly 'ai.djl.pytorch:pytorch-engine:0.7.0'
 runtimeOnly 'ai.djl.android:pytorch-native:0.7.0'

咱们将应用 DJL 提供的 API 以及 PyTorch 包。

第一步:创立 Layout

咱们能够先创立一个 View class 以及 Layout(如下图)来构建安卓的前端显示界面。

如上图所示,咱们能够在主界面创立两个 View 指标。PaintView 是用来让用户画画的,在右下角 ImageView 是用来展现用于深度学习推理的图像。同时咱们预留一个按钮来进行画板的清空操作。

第二部:应答绘画动作

在安卓设施上,咱们能够自定义安卓的触摸事件响应来应答用户的各种触控操作。在咱们的状况下,咱们须要定义上面三种工夫响应:

  • touchStart:感应触碰时触发
  • touchMove:当用户在屏幕上挪动手指时触发
  • touchUp:当用户抬起手指时触发

与此同时,咱们用 paths 来存储用户在画板所绘制的门路。当初看一下实现代码。

重写 OnTouchEvent 和 OnDraw 办法

当初咱们重写 onTouchEvent 来应答各种响应:

@Override
public boolean onTouchEvent(MotionEvent event) {float x = event.getX();
 float y = event.getY();
 switch (event.getAction()) {
 case MotionEvent.ACTION_DOWN :
 touchStart(x, y);
 invalidate();
 break;
 case MotionEvent.ACTION_MOVE :
 touchMove(x, y);
 invalidate();
 break;
 case MotionEvent.ACTION_UP :
 touchUp();
 runInference();
 invalidate();
 break;
 }
 return true;
}

如上述代码所示,咱们能够增加一个 runInference 办法在 MotionEvent.ACTION_UP 事件响应上。这个办法是用来在用户绘制完后对后果进行推理。在之后的几步中,咱们会解说它的具体实现。

咱们同样须要重写 onDraw 办法来展现用户绘制的图像:

@Override
protected void onDraw(Canvas canvas) {canvas.save();
 this.canvas.drawColor(DEFAULT_BG_COLOR);
 for (Path path : paths) {paint.setColor(DEFAULT_PAINT_COLOR);
 paint.setStrokeWidth(BRUSH_SIZE);
 this.canvas.drawPath(path, paint);
 }
 canvas.drawBitmap(bitmap, 0, 0, bitmapPaint);
 canvas.restore();}

真正的图像会保留在一个 Bitmap 上。

touchStart

当用户触碰行为开始时,上面的代码会建设一个新的门路同时记录门路中每一个点在屏幕上的坐标。

private void touchStart(float x, float y) {path = new Path();
 paths.add(path);
 path.reset();
 path.moveTo(x, y);
 this.x = x;
 this.y = y;
}

touchMove

在手指挪动中,咱们会继续记录坐标点而后将它们形成一个 quadratic bezier)。通过肯定的误差阀值来动静优化用户的绘画动作。只有差异超出误差范畴内的动作才会被记录下来。

private void touchMove(float x, float y) {if (x < 0 || x > getWidth() || y < 0 || y > getHeight()) {return;}
 float dx = Math.abs(x - this.x);
 float dy = Math.abs(y - this.y);
 if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {path.quadTo(this.x, this.y, (x + this.x) / 2, (y + this.y) / 2);
 this.x = x;
 this.y = y;
 }
}

touchUp

当触控操作完结后,上面的代码会绘制一个门路同时计算最小长方形指标框。

private void touchUp() {path.lineTo(this.x, this.y);
 maxBound.add(new Path(path));
}

Step 3:开始推理

为了在安卓设施上进行推理工作,咱们须要实现上面几个工作:

  • 从 URL 读取模型
  • 构建前解决和后处理过程
  • 从 PaintView 进行推理工作

为了实现以下指标,咱们尝试构建一个 DoodleModel class。在这一步,咱们将介绍一些实现这些工作的关键步骤。

读取模型

DJL 内建了一套模型管理系统。开发者能够自定义贮存模型的文件夹。

File dir = getFilesDir();
System.setProperty("DJL_CACHE_DIR", dir.getAbsolutePath());

通过更改 DJL_CACHE_DIR 属性,模型会被存入相应门路下。

下一步能够通过定义 Criteria 从指定 URL 处下载模型。下载的 zip 文件内蕴含:

  • doodle_mobilenet.pt:PyTorch 模型
  • synset.txt:内蕴含分类工作中所有类别的名称
Criteria<Image, Classifications> criteria =
 Criteria.builder()
 .setTypes(Image.class, Classifications.class)
 .optModelUrls("https://djl-ai.s3.amazonaws.com/resources/demo/pytorch/doodle_mobilenet.zip")
 .optTranslator(translator)
 .build();
return ModelZoo.loadModel(criteria);

上述代码同时定义了 translator。translator 会被用来做图片的前解决和后处理。

最初,如下述代码创立一个 Model 并用它创立一个 Predictor:

@Override
protected Boolean doInBackground(Void... params) {
 try {model = DoodleModel.loadModel();
 predictor = model.newPredictor();
 return true;
 } catch (IOException | ModelException e) {Log.e("DoodleDraw", null, e);
 }
 return false;
}

更多对于模型加载的信息,请参阅如何加载模型。

用 Translator 定义前解决和后处理

在 DJL 中,咱们定义了 Translator 接口进行前解决和后处理。在 DoodleModel 中咱们定义了 ImageClassificationTranslator 来实现 Translator:

ImageClassificationTranslator.builder()
 .addTransform(new ToTensor())
 .optFlag(Image.Flag.GRAYSCALE)
 .optApplySoftmax(true).build());

上面咱们具体论述 translator 所定义的前解决和后处理如何被用在模型的推理步骤中。当创立 translator 时,外部程序会主动加载 synset.txt 文件失去做分类工作时所有类别的名称。当模型的 predict () 办法被调用时,外部程序会先执行所对应的 translator 的前解决步骤,而后执行理论推理步骤,最初执行 translator 的后处理步骤。对于前解决,咱们会将 Image 转化 NDArray,用于作为模型推理过程的输出。对于后处理,咱们对推理输入的后果(NDArray)进行 softmax 操作。最终返回后果为 Classifications 的一个实例。

更多对于 translator 的工作原理以及如何个性化 Translator 的信息,请参阅 Inference with your model。

Run inference from PaintView

最初,咱们来实现之前定义好的 runInference 办法。

public void runInference() {
 // 拷贝图像
 Bitmap bmp = Bitmap.createBitmap(bitmap);
 // 缩放图像
 bmp = Bitmap.createScaledBitmap(bmp, 64, 64, true);
 // 执行推理工作
 Classifications classifications = model.predict(bmp);
 // 展现输出的图像
 Bitmap present = Bitmap.createScaledBitmap(bmp, imageView.getWidth(), imageView.getHeight(), true);
 imageView.setImageBitmap(present);
 // 展现输入的图像
 if (messageToast != null) {messageToast.cancel();
 }
 messageToast = Toast.makeText(getContext(), classifications.toString(), Toast.LENGTH_SHORT);
 messageToast.show();}

这将会创立一个 Toast 弹出页面用于展现后果,示例如下:

祝贺你!当初你就创立了一个残缺的 Doodle Draw 小程序!

Optional: Optimize input

为了失去更高的模型推理准确度,能够通过截取图像来去除无意义的边框局部。

下面右侧的图片会比右边的图片有更好的推理后果,因为它所蕴含的空白边框更少。咱们能够通过 Bound 类来寻找图片的无效边界,即能把图中所有红色像素点笼罩的最小矩形。在失去 x 轴最左坐标,y 轴最上坐标,以及矩形高度和宽度后,就能够用这些信息截取出咱们想要的图形(如右图所示)实现代码如下:

RectF bound = maxBound.getBound();
int x = (int) bound.left;
int y = (int) bound.top;
int width = (int) Math.ceil(bound.width());
int height = (int) Math.ceil(bound.height());
// 截取局部图像
Bitmap bmp = Bitmap.createBitmap(bitmap, x, y, width, height)

祝贺你!当初你就把握了全副教程内容!期待看到你创立的第一个 DoodleDraw 安卓游戏!

最初,能够在 GitHub 找到本教程的残缺案例代码。

对于 Deep Java Library

Deep Java Library (DJL) 是一个基于 Java 的深度学习框架,同时反对训练以及推理。DJL 博取众长,构建在多个深度学习框架之上(TenserFlow、PyTorch、MXNet 等),也同时具备多个框架的低劣个性。咱们能够轻松应用 DJL 来进行训练而后部署你的模型。

它同时领有着弱小的模型库反对:只需一行便能够轻松读取各种预训练的模型。当初 DJL 的模型库同时反对高达 70 个来自 GluonCV、HuggingFace、TorchHub 以及 Keras 的模型。

我的项目地址:https://github.com/awslabs/djl/

在最新的版本中 DJL 0.7.0 增加了对于 MXNet 1.7.0、PyTorch 1.6.0、TensorFlow 2.3.0 的反对。咱们同时也增加了 ONNXRuntime 以及 PyTorch 在安卓平台的反对。

请参阅咱们的 GitHub、demo repository、Slack channel 和知乎频道获取更多信息!

正文完
 0