huangcm
2025-02-24 69ed55dec4b2116a19e4cca4393cbc014fce5fb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
#define LOG_TAG "DnsTlsTransport"
//#define LOG_NDEBUG 0
 
#include "DnsTlsTransport.h"
 
#include <arpa/inet.h>
#include <arpa/nameser.h>
 
#include "DnsTlsSocketFactory.h"
#include "IDnsTlsSocketFactory.h"
 
#include "log/log.h"
 
namespace android {
namespace net {
 
std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard guard(mLock);
 
    auto record = mQueries.recordQuery(query);
    if (!record) {
        return std::async(std::launch::deferred, []{
            return (Result) { .code = Response::internal_error };
        });
    }
 
    if (!mSocket) {
        ALOGV("No socket for query.  Opening socket and sending.");
        doConnect();
    } else {
        sendQuery(record->query);
    }
 
    return std::move(record->result);
}
 
bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
    // Strip off the ID number and send the new ID instead.
    bool sent = mSocket->query(q.newId, netdutils::drop(q.query, 2));
    if (sent) {
        mQueries.markTried(q.newId);
    }
    return sent;
}
 
void DnsTlsTransport::doConnect() {
    ALOGV("Constructing new socket");
    mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
 
    if (mSocket) {
        auto queries = mQueries.getAll();
        ALOGV("Initialization succeeded.  Reissuing %zu queries.", queries.size());
        for(auto& q : queries) {
            if (!sendQuery(q)) {
                break;
            }
        }
    } else {
        ALOGV("Initialization failed.");
        mSocket.reset();
        ALOGV("Failing all pending queries.");
        mQueries.clear();
    }
}
 
void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
    mQueries.onResponse(std::move(response));
}
 
void DnsTlsTransport::onClosed() {
    std::lock_guard guard(mLock);
    if (mClosing) {
        return;
    }
    // Move remaining operations to a new thread.
    // This is necessary because
    // 1. onClosed is currently running on a thread that blocks mSocket's destructor
    // 2. doReconnect will call that destructor
    if (mReconnectThread) {
        // Complete cleanup of a previous reconnect thread, if present.
        mReconnectThread->join();
        // Joining a thread that is trying to acquire mLock, while holding mLock,
        // looks like it risks a deadlock.  However, a deadlock will not occur because
        // once onClosed is called, it cannot be called again until after doReconnect
        // acquires mLock.
    }
    mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
}
 
void DnsTlsTransport::doReconnect() {
    std::lock_guard guard(mLock);
    if (mClosing) {
        return;
    }
    mQueries.cleanup();
    if (!mQueries.empty()) {
        ALOGV("Fast reconnect to retry remaining queries");
        doConnect();
    } else {
        ALOGV("No pending queries.  Going idle.");
        mSocket.reset();
    }
}
 
DnsTlsTransport::~DnsTlsTransport() {
    ALOGV("Destructor");
    {
        std::lock_guard guard(mLock);
        ALOGV("Locked destruction procedure");
        mQueries.clear();
        mClosing = true;
    }
    // It's possible that a reconnect thread was spawned and waiting for mLock.
    // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
    // but we need to wait for it to finish before allowing destruction to proceed.
    if (mReconnectThread) {
        ALOGV("Waiting for reconnect thread to terminate");
        mReconnectThread->join();
        mReconnectThread.reset();
    }
    // Ensure that the socket is destroyed, and can clean up its callback threads,
    // before any of this object's fields become invalid.
    mSocket.reset();
    ALOGV("Destructor completed");
}
 
// static
// TODO: Use this function to preheat the session cache.
// That may require moving it to DnsTlsDispatcher.
bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
    ALOGV("Beginning validation on %u", netid);
    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
    // order to prove that it is actually a working DNS over TLS server.
    static const char kDnsSafeChars[] =
            "abcdefhijklmnopqrstuvwxyz"
            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
            "0123456789";
    const auto c = [](uint8_t rnd) -> uint8_t {
        return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
    };
    uint8_t rnd[8];
    arc4random_buf(rnd, std::size(rnd));
    // We could try to use res_mkquery() here, but it's basically the same.
    uint8_t query[] = {
        rnd[6], rnd[7],  // [0-1]   query ID
        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0, 1,  // [4-5]   QDCOUNT (number of queries)
        0, 0,  // [6-7]   ANCOUNT (number of answers)
        0, 0,  // [8-9]   NSCOUNT (number of name server records)
        0, 0,  // [10-11] ARCOUNT (number of additional records)
        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
        6, 'm', 'e', 't', 'r', 'i', 'c',
        7, 'g', 's', 't', 'a', 't', 'i', 'c',
        3, 'c', 'o', 'm',
        0,  // null terminator of FQDN (root TLD)
        0, ns_t_aaaa,  // QTYPE
        0, ns_c_in     // QCLASS
    };
    const int qlen = std::size(query);
 
    int replylen = 0;
    DnsTlsSocketFactory factory;
    DnsTlsTransport transport(server, mark, &factory);
    auto r = transport.query(netdutils::Slice(query, qlen)).get();
    if (r.code != Response::success) {
        ALOGV("query failed");
        return false;
    }
 
    const std::vector<uint8_t>& recvbuf = r.response;
    if (recvbuf.size() < NS_HFIXEDSZ) {
        ALOGW("short response: %d", replylen);
        return false;
    }
 
    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
    if (qdcount != 1) {
        ALOGW("reply query count != 1: %d", qdcount);
        return false;
    }
 
    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
    ALOGV("%u answer count: %d", netid, ancount);
 
    // TODO: Further validate the response contents (check for valid AAAA record, ...).
    // Note that currently, integration tests rely on this function accepting a
    // response with zero records.
#if 0
    for (int i = 0; i < resplen; i++) {
        ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
    }
#endif
    return true;
}
 
}  // end of namespace net
}  // end of namespace android