liyujie
2025-08-28 786ff4f4ca2374bdd9177f2e24b503d43e7a3b93
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2014 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
 
import logging
import socket
import struct
import time
 
from autotest_lib.client.common_lib import error
from autotest_lib.client.common_lib.cros.network import interface
 
 
class InterfaceHost(object):
    """A host for use with ZeroconfDaemon that binds to an interface."""
 
    @property
    def ip_addr(self):
        """Get the IP address of the interface we're bound to."""
        return self._interface.ipv4_address
 
 
    def __init__(self, interface_name):
        self._interface = interface.Interface(interface_name)
        self._socket = None
 
 
    def close(self):
        """Close the underlying socket."""
        if self._socket:
            self._socket.close()
 
 
    def socket(self, family, sock_type):
        """Get a socket bound to this interface.
 
        Only supports IPv4 UDP sockets on broadcast addresses.
 
        @param family: must be socket.AF_INET.
        @param sock_type: must be socket.SOCK_DGRAM.
 
        """
        if family != socket.AF_INET or sock_type != socket.SOCK_DGRAM:
            raise error.TestError('InterfaceHost only understands UDP sockets.')
        if self._socket is not None:
            raise error.TestError('InterfaceHost only supports a single '
                                  'multicast socket.')
 
        self._socket = InterfaceDatagramSocket(self.ip_addr)
        return self._socket
 
 
    def run_until(self, predicate, timeout_seconds):
        """Handle traffic from our socket until |predicate|() is true.
 
        @param predicate: function without arguments that returns True or False.
        @param timeout_seconds: number of seconds to wait for predicate to
                                become True.
        @return: tuple(success, duration) where success is True iff predicate()
                 became true before |timeout_seconds| passed.
 
        """
        start_time = time.time()
        duration = lambda: time.time() - start_time
        while duration() < timeout_seconds:
            if predicate():
                return True, duration()
            # Assume this take non-trivial time, don't sleep here.
            self._socket.run_once()
        return False, duration()
 
 
class InterfaceDatagramSocket(object):
    """Broadcast UDP socket bound to a particular network interface."""
 
    # Wait for a UDP frame to appear for this long before timing out.
    TIMEOUT_VALUE_SECONDS = 0.5
 
    def __init__(self, interface_ip):
        """Construct an instance.
 
        @param interface_ip: string like '239.192.1.100'.
 
        """
        self._interface_ip = interface_ip
        self._recv_callback = None
        self._recv_sock = None
        self._send_sock = None
 
 
    def close(self):
        """Close state associated with this object."""
        if self._recv_sock is not None:
            # Closing the socket drops membership groups.
            self._recv_sock.close()
            self._recv_sock = None
        if self._send_sock is not None:
            self._send_sock.close()
            self._send_sock = None
 
 
    def listen(self, ip_addr, port, recv_callback):
        """Bind and listen on the ip_addr:port.
 
        @param ip_addr: Multicast group IP (e.g. '224.0.0.251')
        @param port: Local destination port number.
        @param recv_callback: A callback function that accepts three arguments,
                              the received string, the sender IPv4 address and
                              the sender port number.
 
        """
        if self._recv_callback is not None:
            raise error.TestError('listen() called twice on '
                                  'InterfaceDatagramSocket.')
        # Multicast addresses are in 224.0.0.0 - 239.255.255.255 (rfc5771)
        ip_addr_prefix = ord(socket.inet_aton(ip_addr)[0])
        if ip_addr_prefix < 224 or ip_addr_prefix > 239:
            raise error.TestError('Invalid multicast address.')
 
        self._recv_callback = recv_callback
        # Set up a socket to receive just traffic from the given address.
        self._recv_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._recv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self._recv_sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
                                   socket.inet_aton(ip_addr) +
                                   socket.inet_aton(self._interface_ip))
        self._recv_sock.settimeout(self.TIMEOUT_VALUE_SECONDS)
        self._recv_sock.bind((ip_addr, port))
        # When we send responses, we want to send them from this particular
        # interface.  The easiest way to do this is bind a socket directly to
        # the IP for the interface.  We're going to ignore messages sent to this
        # socket.
        self._send_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._send_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self._send_sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_TTL,
                                   struct.pack('b', 1))
        self._send_sock.bind((self._interface_ip, port))
 
 
    def run_once(self):
        """Receive pending frames if available, return after timeout otw."""
        if self._recv_sock is None:
            raise error.TestError('Must listen() on socket before recv\'ing.')
        BUFFER_SIZE_BYTES = 2048
        try:
            data, sender_addr = self._recv_sock.recvfrom(BUFFER_SIZE_BYTES)
        except socket.timeout:
            return
        if len(sender_addr) != 2:
            logging.error('Unexpected address: %r', sender_addr)
        self._recv_callback(data, *sender_addr)
 
 
    def send(self, data, ip_addr, port):
        """Send |data| to an IPv4 address.
 
        @param data: string of raw bytes to send.
        @param ip_addr: string like '239.192.1.100'.
        @param port: int like 50000.
 
        """
        self._send_sock.sendto(data, (ip_addr, port))