package com.rockchip.gpadc.demo.yolo; import android.graphics.RectF; import android.util.Log; import com.rockchip.gpadc.demo.InferenceResult; import com.rockchip.gpadc.demo.InferenceResult.OutputBuffer; import com.rockchip.gpadc.demo.InferenceResult.Recognition; import com.rockchip.gpadc.demo.InferenceResult.DetectResultGroup; import java.io.IOException; import java.util.ArrayList; /** * Created by randall on 18-4-18. */ public class InferenceWrapper { private final String TAG = "rkyolo.InferenceWrapper"; static { System.loadLibrary("rknn4j"); } OutputBuffer mOutputs; ArrayList mRecognitions = new ArrayList(); DetectResultGroup mDetectResults; public int OBJ_NUMB_MAX_SIZE = 64; // public int inf_count = 0; // public int post_count = 0; // public long inf_time = 0; // public long post_time = 0; public InferenceWrapper() { } public int initModel(int im_height, int im_width, int im_channel, String modelPath) throws Exception { mOutputs = new InferenceResult.OutputBuffer(); mOutputs.mGrid0Out = new byte[255 * 80 * 80]; mOutputs.mGrid1Out = new byte[255 * 40 * 40]; mOutputs.mGrid2Out = new byte[255 * 20 * 20]; if (navite_init(im_height, im_width, im_channel, modelPath) != 0) { throw new IOException("rknn init fail!"); } return 0; } public void deinit() { native_deinit(); mOutputs.mGrid0Out = null; mOutputs.mGrid1Out = null; mOutputs.mGrid2Out = null; mOutputs = null; } public InferenceResult.OutputBuffer run(byte[] inData) { // long startTime = System.currentTimeMillis(); // long endTime; native_run(inData, mOutputs.mGrid0Out, mOutputs.mGrid1Out, mOutputs.mGrid2Out); // this.inf_count += 1; // endTime = System.currentTimeMillis(); // this.inf_time += (endTime - startTime); // if (this.inf_count >= 100) { // float inf_avg = this.inf_time * 1.0f / this.inf_count; // Log.w(TAG, String.format("inference avg cost: %.5f ms", inf_avg)); // this.inf_count = 0; // this.inf_time = 0; // } // Log.i(TAG, String.format("inference count: %d", this.inf_count)); return mOutputs; } public ArrayList postProcess(InferenceResult.OutputBuffer outputs) { ArrayList recognitions = new ArrayList(); mDetectResults = new DetectResultGroup(); mDetectResults.count = 0; mDetectResults.ids = new int[OBJ_NUMB_MAX_SIZE]; mDetectResults.scores = new float[OBJ_NUMB_MAX_SIZE]; mDetectResults.boxes = new float[4 * OBJ_NUMB_MAX_SIZE]; if (null == outputs || null == outputs.mGrid0Out || null == outputs.mGrid1Out || null == outputs.mGrid2Out) { return recognitions; } // long startTime = System.currentTimeMillis(); // long endTime; int count = native_post_process(outputs.mGrid0Out, outputs.mGrid1Out, outputs.mGrid2Out, mDetectResults.ids, mDetectResults.scores, mDetectResults.boxes); if (count < 0) { Log.w(TAG, "post_process may fail."); mDetectResults.count = 0; } else { mDetectResults.count = count; } // Log.i(TAG, String.format("Detected %d objects", count)); // this.post_count += 1; // Log.i(TAG, String.format("post count: %d", this.post_count)); for (int i = 0; i < count; ++i) { RectF rect = new RectF(); rect.left = mDetectResults.boxes[i*4+0]; rect.top = mDetectResults.boxes[i*4+1]; rect.right = mDetectResults.boxes[i*4+2]; rect.bottom = mDetectResults.boxes[i*4+3]; Recognition recog = new InferenceResult.Recognition(mDetectResults.ids[i], mDetectResults.scores[i], rect); recognitions.add(recog); } // endTime = System.currentTimeMillis(); // this.post_time += (endTime - startTime); // if (this.post_count >= 100) { // float post_avg = this.post_time * 1.0f / this.post_count; // Log.w(TAG, String.format("post process avg cost: %.5f ms", post_avg)); // this.post_time = 0; // this.post_count = 0; // } return recognitions; } private native int navite_init(int im_height, int im_width, int im_channel, String modelPath); private native void native_deinit(); private native int native_run(byte[] inData, byte[] grid0Out, byte[] grid1Out, byte[] grid2Out); private native int native_post_process(byte[] grid0Out, byte[] grid1Out, byte[] grid2Out, int[] ids, float[] scores, float[] boxes); }