/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#define LOG_TAG "DnsTlsTransport"
|
//#define LOG_NDEBUG 0
|
|
#include "DnsTlsTransport.h"
|
|
#include <arpa/inet.h>
|
#include <arpa/nameser.h>
|
|
#include "DnsTlsSocketFactory.h"
|
#include "IDnsTlsSocketFactory.h"
|
|
#include "log/log.h"
|
|
namespace android {
|
namespace net {
|
|
// Registers |query| with the pending-query map and returns a future for its result.
//
// Holds mLock for the whole call. If the query cannot be recorded, the returned
// future is a deferred task that yields Response::internal_error. Otherwise the
// query is sent on the existing socket, or, when there is no socket yet, a
// connection attempt is started (doConnect() reissues every recorded query,
// including this one, once a socket exists).
std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
    std::lock_guard lock(mLock);

    auto pending = mQueries.recordQuery(query);
    if (!pending) {
        // Nothing was recorded, so there is no real future to hand back;
        // synthesize one that reports the failure when the caller waits on it.
        return std::async(std::launch::deferred, [] {
            Result failure{};
            failure.code = Response::internal_error;
            return failure;
        });
    }

    if (mSocket) {
        // The outcome of the send is intentionally not inspected here; the
        // caller observes the result through the returned future.
        sendQuery(pending->query);
    } else {
        ALOGV("No socket for query. Opening socket and sending.");
        doConnect();
    }

    // |result| is a member of the record, so it must be moved out explicitly.
    return std::move(pending->result);
}
|
|
// Writes |q| to the current socket and, on success, marks it as tried in the
// pending-query map. Returns whether the socket accepted the query.
//
// Callers in this file only invoke this when mSocket is non-null.
bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
    // The first two bytes of the packet are the original query ID; drop them
    // and let the socket send q.newId in their place.
    const bool accepted = mSocket->query(q.newId, netdutils::drop(q.query, 2));
    if (!accepted) {
        return false;
    }
    mQueries.markTried(q.newId);
    return true;
}
|
|
void DnsTlsTransport::doConnect() {
|
ALOGV("Constructing new socket");
|
mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
|
|
if (mSocket) {
|
auto queries = mQueries.getAll();
|
ALOGV("Initialization succeeded. Reissuing %zu queries.", queries.size());
|
for(auto& q : queries) {
|
if (!sendQuery(q)) {
|
break;
|
}
|
}
|
} else {
|
ALOGV("Initialization failed.");
|
mSocket.reset();
|
ALOGV("Failing all pending queries.");
|
mQueries.clear();
|
}
|
}
|
|
// Socket callback: hands a received response packet to the pending-query map.
// Note that mLock is not taken here.
void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
    // |response| was taken by value, so ownership can be passed straight on.
    mQueries.onResponse(std::move(response));
}
|
|
void DnsTlsTransport::onClosed() {
|
std::lock_guard guard(mLock);
|
if (mClosing) {
|
return;
|
}
|
// Move remaining operations to a new thread.
|
// This is necessary because
|
// 1. onClosed is currently running on a thread that blocks mSocket's destructor
|
// 2. doReconnect will call that destructor
|
if (mReconnectThread) {
|
// Complete cleanup of a previous reconnect thread, if present.
|
mReconnectThread->join();
|
// Joining a thread that is trying to acquire mLock, while holding mLock,
|
// looks like it risks a deadlock. However, a deadlock will not occur because
|
// once onClosed is called, it cannot be called again until after doReconnect
|
// acquires mLock.
|
}
|
mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
|
}
|
|
void DnsTlsTransport::doReconnect() {
|
std::lock_guard guard(mLock);
|
if (mClosing) {
|
return;
|
}
|
mQueries.cleanup();
|
if (!mQueries.empty()) {
|
ALOGV("Fast reconnect to retry remaining queries");
|
doConnect();
|
} else {
|
ALOGV("No pending queries. Going idle.");
|
mSocket.reset();
|
}
|
}
|
|
// Tears the transport down in three ordered steps: clear pending queries and
// flag shutdown under mLock, wait for any reconnect thread, then destroy the
// socket explicitly.
DnsTlsTransport::~DnsTlsTransport() {
    ALOGV("Destructor");
    {
        std::lock_guard lock(mLock);
        ALOGV("Locked destruction procedure");
        mQueries.clear();
        mClosing = true;
    }

    // A reconnect thread may already have been spawned and be blocked on mLock.
    // Letting it run now is safe, since mClosing is true (and mQueries is empty),
    // but destruction must not proceed until it has finished.
    if (mReconnectThread) {
        ALOGV("Waiting for reconnect thread to terminate");
        mReconnectThread->join();
        mReconnectThread.reset();
    }

    // Destroy the socket here, so it can clean up its callback threads before
    // any of this object's fields become invalid.
    mSocket.reset();
    ALOGV("Destructor completed");
}
|
|
// static
|
// TODO: Use this function to preheat the session cache.
|
// That may require moving it to DnsTlsDispatcher.
|
bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
|
ALOGV("Beginning validation on %u", netid);
|
// Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
|
// order to prove that it is actually a working DNS over TLS server.
|
static const char kDnsSafeChars[] =
|
"abcdefhijklmnopqrstuvwxyz"
|
"ABCDEFHIJKLMNOPQRSTUVWXYZ"
|
"0123456789";
|
const auto c = [](uint8_t rnd) -> uint8_t {
|
return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
|
};
|
uint8_t rnd[8];
|
arc4random_buf(rnd, std::size(rnd));
|
// We could try to use res_mkquery() here, but it's basically the same.
|
uint8_t query[] = {
|
rnd[6], rnd[7], // [0-1] query ID
|
1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
|
0, 1, // [4-5] QDCOUNT (number of queries)
|
0, 0, // [6-7] ANCOUNT (number of answers)
|
0, 0, // [8-9] NSCOUNT (number of name server records)
|
0, 0, // [10-11] ARCOUNT (number of additional records)
|
17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
|
'-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
|
6, 'm', 'e', 't', 'r', 'i', 'c',
|
7, 'g', 's', 't', 'a', 't', 'i', 'c',
|
3, 'c', 'o', 'm',
|
0, // null terminator of FQDN (root TLD)
|
0, ns_t_aaaa, // QTYPE
|
0, ns_c_in // QCLASS
|
};
|
const int qlen = std::size(query);
|
|
int replylen = 0;
|
DnsTlsSocketFactory factory;
|
DnsTlsTransport transport(server, mark, &factory);
|
auto r = transport.query(netdutils::Slice(query, qlen)).get();
|
if (r.code != Response::success) {
|
ALOGV("query failed");
|
return false;
|
}
|
|
const std::vector<uint8_t>& recvbuf = r.response;
|
if (recvbuf.size() < NS_HFIXEDSZ) {
|
ALOGW("short response: %d", replylen);
|
return false;
|
}
|
|
const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
|
if (qdcount != 1) {
|
ALOGW("reply query count != 1: %d", qdcount);
|
return false;
|
}
|
|
const int ancount = (recvbuf[6] << 8) | recvbuf[7];
|
ALOGV("%u answer count: %d", netid, ancount);
|
|
// TODO: Further validate the response contents (check for valid AAAA record, ...).
|
// Note that currently, integration tests rely on this function accepting a
|
// response with zero records.
|
#if 0
|
for (int i = 0; i < resplen; i++) {
|
ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
|
}
|
#endif
|
return true;
|
}
|
|
} // end of namespace net
|
} // end of namespace android
|