hc
2024-05-10 37f49e37ab4cb5d0bc4c60eb5c6d4dd57db767bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package com.rockchip.gpadc.demo.yolo;
 
import android.graphics.RectF;
import android.util.Log;
 
import com.rockchip.gpadc.demo.InferenceResult;
import com.rockchip.gpadc.demo.InferenceResult.OutputBuffer;
import com.rockchip.gpadc.demo.InferenceResult.Recognition;
import com.rockchip.gpadc.demo.InferenceResult.DetectResultGroup;
 
import java.io.IOException;
import java.util.ArrayList;
 
/**
 * Created by randall on 18-4-18.
 */
 
public class InferenceWrapper {
    private final String TAG = "rkyolo.InferenceWrapper";
 
    static {
        System.loadLibrary("rknn4j");
    }
 
    OutputBuffer mOutputs;
    ArrayList<Recognition> mRecognitions = new ArrayList<Recognition>();
    DetectResultGroup mDetectResults;
 
    public int OBJ_NUMB_MAX_SIZE = 64;
//    public int inf_count = 0;
//    public int post_count = 0;
//    public long inf_time = 0;
//    public long post_time = 0;
 
 
    public InferenceWrapper() {
 
    }
 
    public int initModel(int im_height, int im_width, int im_channel, String modelPath) throws Exception {
        mOutputs = new InferenceResult.OutputBuffer();
        mOutputs.mGrid0Out = new byte[255 * 80 * 80];
        mOutputs.mGrid1Out = new byte[255 * 40 * 40];
        mOutputs.mGrid2Out = new byte[255 * 20 * 20];
        if (navite_init(im_height, im_width, im_channel, modelPath) != 0) {
            throw new IOException("rknn init fail!");
        }
        return 0;
    }
 
 
    public void deinit() {
        native_deinit();
        mOutputs.mGrid0Out = null;
        mOutputs.mGrid1Out = null;
        mOutputs.mGrid2Out = null;
        mOutputs = null;
 
    }
 
    public InferenceResult.OutputBuffer run(byte[] inData) {
//        long startTime = System.currentTimeMillis();
//        long endTime;
        native_run(inData, mOutputs.mGrid0Out, mOutputs.mGrid1Out, mOutputs.mGrid2Out);
//        this.inf_count += 1;
//        endTime = System.currentTimeMillis();
//        this.inf_time += (endTime - startTime);
//        if (this.inf_count >= 100) {
//            float inf_avg = this.inf_time * 1.0f / this.inf_count;
//            Log.w(TAG, String.format("inference avg cost: %.5f ms", inf_avg));
//            this.inf_count = 0;
//            this.inf_time = 0;
//        }
//        Log.i(TAG, String.format("inference count: %d", this.inf_count));
        return  mOutputs;
    }
 
    public ArrayList<InferenceResult.Recognition> postProcess(InferenceResult.OutputBuffer outputs) {
        ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
 
        mDetectResults = new DetectResultGroup();
        mDetectResults.count = 0;
        mDetectResults.ids = new int[OBJ_NUMB_MAX_SIZE];
        mDetectResults.scores = new float[OBJ_NUMB_MAX_SIZE];
        mDetectResults.boxes = new float[4 * OBJ_NUMB_MAX_SIZE];
 
        if (null == outputs || null == outputs.mGrid0Out || null == outputs.mGrid1Out
                || null == outputs.mGrid2Out) {
            return recognitions;
        }
 
//        long startTime = System.currentTimeMillis();
//        long endTime;
        int count = native_post_process(outputs.mGrid0Out, outputs.mGrid1Out, outputs.mGrid2Out,
                mDetectResults.ids, mDetectResults.scores, mDetectResults.boxes);
        if (count < 0) {
            Log.w(TAG, "post_process may fail.");
            mDetectResults.count = 0;
        } else {
            mDetectResults.count = count;
        }
//        Log.i(TAG, String.format("Detected %d objects", count));
//        this.post_count += 1;
//        Log.i(TAG, String.format("post count: %d", this.post_count));
 
        for (int i = 0; i < count; ++i) {
            RectF rect = new RectF();
            rect.left = mDetectResults.boxes[i*4+0];
            rect.top = mDetectResults.boxes[i*4+1];
            rect.right = mDetectResults.boxes[i*4+2];
            rect.bottom = mDetectResults.boxes[i*4+3];
 
            Recognition recog = new InferenceResult.Recognition(mDetectResults.ids[i],
                    mDetectResults.scores[i], rect);
            recognitions.add(recog);
        }
//        endTime = System.currentTimeMillis();
//        this.post_time += (endTime - startTime);
//        if (this.post_count >= 100) {
//            float post_avg = this.post_time * 1.0f / this.post_count;
//            Log.w(TAG, String.format("post process avg cost: %.5f ms", post_avg));
//            this.post_time = 0;
//            this.post_count = 0;
//        }
 
        return recognitions;
    }
 
    private native int navite_init(int im_height, int im_width, int im_channel, String modelPath);
    private native void native_deinit();
    private native int native_run(byte[] inData, byte[] grid0Out, byte[] grid1Out, byte[] grid2Out);
    private native int native_post_process(byte[] grid0Out, byte[] grid1Out, byte[] grid2Out,
                                           int[] ids, float[] scores, float[] boxes);
 
}