Android Tensorflow Lite 使用的问题

it2024-02-22  69

说明

本文仅介绍在Android Studio 编程中使用TFLite的问题,关于Python的问题不会涉及;

将转换好的.tflite文件放到android工程目录下的assets中;加载文件的方法: private MappedByteBuffer loadModelFile(Activity activity) throws IOException { AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_NAME); //这里的MODEL_NAME是文件的名字,我的是net.tflite; FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } 调用运算的方法: //predict类的构造方法,(loadModelFile和 下面的classify都是Predict类中的成员方法 ): public Predict(Activity activity) throws IOException { mInterpreter = new Interpreter(loadModelFile(activity), mOptions); mECGData = ByteBuffer.allocateDirect(4*INPUT_LEN); //这里一定要用allocateDirect方法,不然会出现运算结果一致不变的问题 mECGData.order(ByteOrder.nativeOrder()); } public Result classify(XYSeries mXYSeries){ Log.d(LOG_TAG, String.valueOf(mXYSeries.getItemCount())); mECGData.rewind(); for(int i=0;i<INPUT_LEN;i++){ mECGData.putFloat((float) mXYSeries.getY(i)); // long startTime = SystemClock.uptimeMillis();//计算开始时间 mInterpreter.run(mECGData, mResult);//跑模型计算,输入mImageData,输出mResult; long endTime = SystemClock.uptimeMillis();//计算结束时间 long timeCost = endTime - startTime;//计算所用时间 Log.v(LOG_TAG, "classify(): result = " + Arrays.toString(mResult[0]) + ", timeCost = " + timeCost); return new Result(mResult[0], timeCost); } 使用的过程: 初始化Predict类,调用classify方法;

遇到的问题

添加依赖项:gradle:app implementation 'org.tensorflow:parentpom:1.14.0' implementation 'org.tensorflow:tensorflow-lite:+' 不压缩.tflite文件:gradle:app aaptOptions { noCompress "tflite" //表示不让aapt压缩的文件后缀 noCompress "lite" } //在android{}中添加 神经网络运算结果为固定值的问题:这个也是让我搞了半天才解决的问题,不知道为什么,在输入不同的数据时,分类的结果都是一样的;找了很久,也没有找到合理的解释。最后好像看到一篇博客说,ByteBuffer在申明初始化的时候,调用的方法很重要:interpreter.run方法好像大家用ByteBuffer比较多,我之前用的是FloatBuffer,运行结果不对,还有一定要用ByteBuffer的.allocateDirect方法,他还有一个allocate方法,也不太行; private ByteBuffer mECGData; mECGData = ByteBuffer.allocateDirect(4*INPUT_LEN); mECGData.order(ByteOrder.nativeOrder()); mECGData.rewind(); for(int i=0;i<INPUT_LEN;i++){ mECGData.putFloat((float) mXYSeries.getY(i)); } ``
最新回复(0)