[winswitch] Reworking Encryption in Xpra

Michael Vrable mvrable at google.com
Wed Oct 31 05:50:56 GMT 2012


On Tue, Oct 30, 2012 at 10:46:30PM -0700, Michael Vrable wrote:
> Attached is a first patch (still needs to be tested) at adding better
> transport-layer encryption to Xpra--it adds message authentication to each of
> the packets to prevent any tampering of the data stream.  Please don't commit
> it, as it isn't ready for that yet.

Does the mailing list strip attachments?  I'm not sure it went through, so here
it is again inline.

--Michael Vrable


commit 1f7e729bbb1e641506ca36ee82bd78c6ff009595
Author: Michael Vrable <mvrable at google.com>

Proof-of-concept implementation of a more secure transport layer for Xpra.

This uses AES encryption of data packets (in CTR mode to avoid the need
for padding), and a truncated HMAC-SHA-256 to provide authentication of
the data stream.
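
A minimal, stdlib-only sketch of the authentication step laid out in the AESCrypto docstring below. The tag is HMAC-SHA-256 over three inputs: a 64-bit big-endian packet counter, the plaintext header, and the ciphertext. It is truncated to 12 bytes. The ciphertext is treated as opaque bytes because AES is not in the Python standard library. compute_tag and verify_tag are hypothetical helper names. Using hmac.compare_digest (Python 2.7.7+/3.3+) instead of the patch's != comparison is a substitution, not what the patch does; it avoids leaking how many leading tag bytes matched.

```python
import hashlib
import hmac
import struct


def compute_tag(mac_key, packet_counter, header, ciphertext, mac_bytes=12):
    """Truncated HMAC-SHA-256 over counter || header || ciphertext."""
    mac = hmac.new(mac_key, digestmod=hashlib.sha256)
    mac.update(struct.pack("!Q", packet_counter))  # 64-bit big-endian counter
    mac.update(header)                             # plaintext 8-byte header
    mac.update(ciphertext)                         # encrypt-then-MAC
    return mac.digest()[:mac_bytes]


def verify_tag(mac_key, packet_counter, header, ciphertext, tag, mac_bytes=12):
    expected = compute_tag(mac_key, packet_counter, header, ciphertext, mac_bytes)
    # Constant-time comparison (the patch itself uses !=).
    return hmac.compare_digest(expected, tag)


if __name__ == "__main__":
    key = b"k" * 32
    header = struct.pack("!BBBBL", ord("P"), 0, 0, 0, 5)
    ciphertext = b"\x01\x02\x03\x04\x05"  # stands in for AES-CTR output
    tag = compute_tag(key, 0, header, ciphertext)
    print(len(tag))                                          # 12
    print(verify_tag(key, 0, header, ciphertext, tag))       # True
    print(verify_tag(key, 1, header, ciphertext, tag))       # False: wrong counter
    print(verify_tag(key, 0, header, b"\x00" + ciphertext[1:], tag))  # False: tampered
```

Truncating to 12 bytes keeps per-packet overhead small. It also leaves an attacker roughly a 2**-96 chance per forged packet.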

This assumes that both sides have run some type of key-agreement
protocol to establish a shared session secret.  I'm working on the key
exchange part in a separate patch which will follow.
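
Everything depends on that shared secret. The sketch below shows how one secret fans out into separate per-direction keys, mirroring derive_key in AESCrypto.__init__ below. It is written for Python 3, so the secret and labels are bytes. The secret value is a placeholder for whatever the key exchange produces. An HMAC over a label is a simple KDF. HKDF (RFC 5869) might be a more standard choice, but that is not what this patch uses.

```python
import hashlib
import hmac


def derive_key(session_secret, context, subtype):
    """HMAC-SHA-256(session_secret, "<context>-<subtype>"): a 32-byte key."""
    label = ("%s-%s" % (context, subtype)).encode("ascii")
    return hmac.new(session_secret, label, digestmod=hashlib.sha256).digest()


def direction_keys(session_secret, context):
    """Return (aes_key, mac_key) for one direction, e.g. "client" or "server"."""
    return (derive_key(session_secret, context, "aes"),
            derive_key(session_secret, context, "mac"))


if __name__ == "__main__":
    secret = b"output of the key-agreement step"  # placeholder value
    client_aes, client_mac = direction_keys(secret, "client")
    server_aes, server_mac = direction_keys(secret, "server")
    # Four distinct 256-bit keys derived from one shared secret.
    print(len({client_aes, client_mac, server_aes, server_mac}))  # 4
    print(len(client_aes) * 8)                                    # 256
```

The context string is part of every label, so the two traffic directions never share an AES key. That separation is what makes starting each CTR stream at a fixed counter value acceptable.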

This code isn't yet tested, but should give a basic idea.

diff --git a/src/xpra/protocol.py b/src/xpra/protocol.py
index 2fdec82..95da1ad 100644
--- a/src/xpra/protocol.py
+++ b/src/xpra/protocol.py
@@ -11,12 +11,16 @@
 from wimpiggy.gobject_compat import import_gobject
 gobject = import_gobject()
 gobject.threads_init()
+import hashlib
+import hmac
 import sys
 import socket # for socket.error
 import zlib
 import struct
 import time
 import os
+from Crypto.Cipher import AES
+from Crypto.Util import Counter
 
 NOYIELD = os.environ.get("XPRA_YIELD") is None
 
@@ -69,6 +73,168 @@ def zlib_compress(datatype, data, level=5):
     return ZLibCompressed(datatype, cdata, level)
 
 
+class CryptoError(ValueError):
+    """Error raised when decryption fails for any reason."""
+    pass
+
+
+class TransportCrypto:
+    """Interface for transport-level encryption layers.
+
+    This class defines the methods that must be implemented to define a
+    transport-layer encryption/authentication method.  The crypto layer is
+    generally set up after both parties have performed mutual authentication
+    and established a shared secret value, which is used to key the
+    encryption/MAC primitives.
+
+    This class does not provide an implementation and so should not be
+    instantiated directly.  Use one of NullCrypto or AESCrypto.
+    """
+
+    def name(self):
+        """Returns the name of the current crypto layer."""
+        raise NotImplementedError
+
+    def overhead_bytes(self, packet_len):
+        """Returns the number of bytes added to a packet of the given size.
+
+        This allows a crypto transport to add extra data for an IV, padding, or
+        MAC.
+        """
+        raise NotImplementedError
+
+    def encrypt(self, header, payload):
+        """Encrypt/MAC the specified data payload.
+
+        Returns the modified payload, which will be of length
+            len(payload) + overhead_bytes(len(payload))
+        The encrypted payload does not include the header, but the computed MAC
+        may cover the values in the header.
+        """
+        raise NotImplementedError
+
+    def decrypt(self, header, payload):
+        """Decrypt the specified data payload.
+
+        Returns the decrypted payload data, or raises a CryptoError exception
+        on any failures (bad MAC, incorrect padding if padding is used, etc.).
+        """
+        raise NotImplementedError
+
+
+class NullCrypto(TransportCrypto):
+    """A transparent transport which performs no encryption/authentication."""
+
+    def name(self):
+        return "null"
+
+    def overhead_bytes(self, packet_len):
+        return 0
+
+    def encrypt(self, header, payload):
+        return payload
+
+    def decrypt(self, header, payload):
+        return payload
+
+
+class AESCrypto(TransportCrypto):
+    """A crypto layer using AES256 in CTR mode and HMAC-SHA-256.
+
+    For each communication direction, two keys are derived from the shared
+    session secret: one encryption key and one message authentication key.
+    Data payloads are encrypted with AES in CTR mode, starting with a counter
+    value of all zeroes.  No padding is needed.
+
+    A message authentication code (MAC) is appended to each packet; the MAC is
+    computed using HMAC-SHA-256 (keyed with the authentication key).  The
+    digest is computed over the concatenation of a 64-bit, big-endian packet
+    counter (to prevent packet replay/reorder attacks), the unencrypted packet
+    header, and the encrypted packet data.  The hash is truncated to mac_bytes
+    in length (default: 12 bytes = 96 bits) then appended after the packet
+    data.  The length field encoded in the packet header is the length of the
+    data payload, not including the header and MAC.
+    """
+
+    def __init__(self, session_secret, context, mac_bytes=12):
+        """Initializes the crypto layer for one direction of a transport.
+
+        Args:
+            session_secret: A secret value negotiated by both sides of the
+                connection.  The value of session_secret must never be re-used
+                in a different connection: doing so would repeat the AES-CTR
+                keystream and MAC keys.  This may be any string value.
+            context: A string used to distinguish multiple related
+                instantiations of AESCrypto.  For example, a client and server
+                may compute a single shared session_secret, and use a context
+                of "server" for data sent by the server and "client" for data
+                sent by the client.  This ensures that separate keys are used
+                for each data direction.
+            mac_bytes: Number of bytes to include in the message authentication
+                code for each packet.  The MAC is a truncated HMAC-SHA-256;
+                mac_bytes can be up to 32; smaller values reduce overhead but
+                raise the probability (about 2**(-8 * mac_bytes) per attempt)
+                that a forged or corrupted packet goes undetected.
+        """
+        self._session_secret = session_secret
+        self._context = context
+
+        # Derived keys.  There are two:
+        #   - A 256-bit AES key for encryption
+        #   - A key used for HMAC authentication
+        def derive_key(subtype):
+            return hmac.new(session_secret, "%s-%s" % (context, subtype),
+                            digestmod=hashlib.sha256).digest()
+        key_enc = derive_key("aes")
+        key_mac = derive_key("mac")
+
+        # Start CTR mode counting from zero.  This is safe as long as every
+        # session (and direction) uses a unique encryption key.
+        self._cipher = AES.new(key_enc, mode=AES.MODE_CTR,
+                               counter=Counter.new(128, initial_value=0))
+
+        self._hmac_key = key_mac
+        self._mac_bytes = mac_bytes
+
+        # The packet counter.  This is incremented for each packet processed.
+        # The packet counter is included in the MAC for each packet (to prevent
+        # replay/reordering attacks), but is not explicitly added to the output
+        # to reduce overhead.
+        self._packet_counter = 0
+
+    def name(self):
+        return "aes256-ctr/hmac256-%d" % (self._mac_bytes * 8)
+
+    def overhead_bytes(self, packet_len):
+        # No padding is needed for CTR mode, so the only overhead is the MAC
+        # overhead, independent of packet size.
+        return self._mac_bytes
+
+    def encrypt(self, header, payload):
+        payload = self._cipher.encrypt(payload)
+        mac = hmac.new(self._hmac_key, digestmod=hashlib.sha256)
+        mac.update(struct.pack("!Q", self._packet_counter))
+        self._packet_counter += 1
+        mac.update(header)
+        mac.update(payload)
+        mac_value = mac.digest()[0:self._mac_bytes]
+        return payload + mac_value
+
+    def decrypt(self, header, payload):
+        if len(payload) < self._mac_bytes:
+            raise CryptoError("Bad decryption")
+        payload_data = payload[:-self._mac_bytes]
+        payload_mac = payload[-self._mac_bytes:]
+        mac = hmac.new(self._hmac_key, digestmod=hashlib.sha256)
+        mac.update(struct.pack("!Q", self._packet_counter))
+        self._packet_counter += 1
+        mac.update(header)
+        mac.update(payload_data)
+        if mac.digest()[0:self._mac_bytes] != payload_mac:
+            raise CryptoError("Bad decryption")
+        return self._cipher.decrypt(payload_data)
+
+
 class Protocol(object):
     CONNECTION_LOST = "connection-lost"
     GIBBERISH = "gibberish"
@@ -97,12 +263,8 @@ class Protocol(object):
         self._encoder = self.bencode
         self._decompressor = zlib.decompressobj()
         self._compression_level = 0
-        self.cipher_in = None
-        self.cipher_in_name = None
-        self.cipher_in_block_size = 0
-        self.cipher_out = None
-        self.cipher_out_name = None
-        self.cipher_out_block_size = 0
+        self.cipher_in = NullCrypto()
+        self.cipher_out = NullCrypto()
         def make_daemon_thread(target, name):
             daemon_thread = Thread(target=target, name=name)
             daemon_thread.setDaemon(True)
@@ -112,33 +274,15 @@ class Protocol(object):
         self._read_thread = make_daemon_thread(self._read_thread_loop, "read_loop")
         self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "read_parse_loop")
 
-    def get_cipher(self, ciphername, iv, password, key_salt, iterations):
-        log("get_cipher_in(%s, %s, %s, %s, %s)", ciphername, iv, password, key_salt, iterations)
-        if not ciphername:
-            return None, 0
-        assert iterations>=100
-        assert ciphername=="AES"
-        assert password and iv
-        from Crypto.Cipher import AES
-        from Crypto.Protocol.KDF import PBKDF2
-        #stretch the password:
-        block_size = 32         #fixme: can we derive this?
-        secret = PBKDF2(password, key_salt, dkLen=block_size, count=iterations)
-        #secret = (password+password+password+password+password+password+password+password)[:32]
-        log("get_cipher(%s, %s, %s) secret=%s, block_size=%s", ciphername, iv, password, secret.encode('hex'), block_size)
-        return AES.new(secret, AES.MODE_CBC, iv), block_size
-
-    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations):
-        if self.cipher_in_name!=ciphername:
-            log.info("receiving data using %s encryption", ciphername)
-            self.cipher_in_name = ciphername
-        self.cipher_in, self.cipher_in_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations)
-
-    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations):
-        if self.cipher_out_name!=ciphername:
-            log.info("sending data using %s encryption", ciphername)
-            self.cipher_out_name = ciphername
-        self.cipher_out, self.cipher_out_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations)
+    def set_cipher(self, direction, ciphername, session_secret, context):
+        cipher = AESCrypto(session_secret, context)
+        log.info("setting %s encryption (context %s) to %s", direction, context, cipher.name())
+        if direction == "in":
+            self.cipher_in = cipher
+        elif direction == "out":
+            self.cipher_out = cipher
+        else:
+            raise ValueError("Unknown cipher direction: " + direction)
 
     def __str__(self):
         ti = ["%s:%s" % (x.name, x.is_alive()) for x in self.get_threads()]
@@ -289,29 +433,12 @@ class Protocol(object):
                 #fire the end_send callback when the last packet (index==0) makes it out:
                 if index==0:
                     ecb = end_send_cb
-                if self.cipher_out:
-                    proto_flags |= Protocol.FLAGS_CIPHER
-                    #note: since we are padding: l!=len(data)
-                    padding = (self.cipher_out_block_size - len(data) % self.cipher_out_block_size) * " "
-                    if len(padding)==0:
-                        padded = data
-                    else:
-                        padded = data+padding
-                    actual_size = payload_size + len(padding)
-                    assert len(padded)==actual_size
-                    data = self.cipher_out.encrypt(padded)
-                    assert len(data)==actual_size
-                    log("sending %s bytes encrypted with %s padding", payload_size, len(padding))
-                if actual_size<16384:
-                    #'p' + protocol-flags + compression_level + packet_index + data_size
-                    if type(data)==unicode:
-                        data = str(data)
-                    header_and_data = struct.pack('!BBBBL%ss' % actual_size, ord("P"), proto_flags, level, index, payload_size, data)
-                    self._write_queue.put((header_and_data, scb, ecb))
-                else:
-                    header = struct.pack('!BBBBL', ord("P"), proto_flags, level, index, payload_size)
-                    self._write_queue.put((header, scb, None))
-                    self._write_queue.put((data, None, ecb))
+                header = struct.pack('!BBBBL',
+                                     ord("P"), proto_flags, level,
+                                     index, payload_size)
+                data = self.cipher_out.encrypt(header, data)
+                self._write_queue.put((header, scb, None))
+                self._write_queue.put((data, None, ecb))
                 counter += 1
         finally:
             self.output_packetcount += 1
@@ -437,20 +564,14 @@ class Protocol(object):
                         break   #packet still too small
                     #packet format: struct.pack('cBBBL', ...) - 8 bytes
                     try:
-                        _, protocol_flags, compression_level, packet_index, data_size = struct.unpack_from('!cBBBL', read_buffer)
+                        header = read_buffer[:8]
+                        _, protocol_flags, compression_level, packet_index, data_size = struct.unpack('!cBBBL', header)
                     except Exception, e:
                         raise Exception("invalid packet header: %s" % list(read_buffer[:8]), e)
                     read_buffer = read_buffer[8:]
                     bl = len(read_buffer)
-                    if protocol_flags & Protocol.FLAGS_CIPHER:
-                        assert self.cipher_in_block_size>0, "received cipher block but we don't have a cipher do decrypt it with"
-                        padding = (self.cipher_in_block_size - data_size % self.cipher_in_block_size) * " "
-                        payload_size = data_size + len(padding)
-                    else:
-                        #no cipher, no padding:
-                        padding = None
-                        payload_size = data_size
-                    assert payload_size>0
+                    payload_size = (data_size +
+                                    self.cipher_in.overhead_bytes(data_size))
 
                 if payload_size>self.max_packet_size:
                     #this packet is seemingly too big, but check again from the main UI thread
@@ -474,13 +595,11 @@ class Protocol(object):
                     raw_string = read_buffer[:payload_size]
                     read_buffer = read_buffer[payload_size:]
                 #decrypt if needed:
-                data = raw_string
-                if self.cipher_in and protocol_flags & Protocol.FLAGS_CIPHER:
-                    log("received %s encrypted bytes with %s padding", payload_size, len(padding))
-                    data = self.cipher_in.decrypt(raw_string)
-                    if padding:
-                        assert data.endswith(padding), "decryption failed: string does not end with '%s': %s (%s) -> %s (%s)" % (padding, list(bytearray(raw_string)), type(raw_string), list(bytearray(data)), type(data))
-                        data = data[:-len(padding)]
+                try:
+                    data = self.cipher_in.decrypt(header, raw_string)
+                    if len(data) != data_size: raise CryptoError("length mismatch")
+                except CryptoError:
+                    return self._call_connection_lost("Decryption failed: %s" % repr_ellipsized(raw_string))
                 #uncompress if needed:
                 if compression_level>0:
                     if self.chunked_compression:
@@ -490,9 +609,6 @@ class Protocol(object):
                 if sys.version>='3':
                     data = data.decode("latin1")
 
-                if self.cipher_in and not (protocol_flags & Protocol.FLAGS_CIPHER):
-                    return self._call_connection_lost("unencrypted packet dropped: %s" % repr_ellipsized(data))
-
                 if self._closed:
                     return
                 if packet_index>0:
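
With this patch applied, each packet on the wire has three parts: the 8-byte header, data_size bytes of AES-CTR ciphertext, and the truncated MAC. data_size excludes the MAC, and the receiver adds overhead_bytes() itself. The sketch below models that framing and the implicit packet counter using only the standard library. Direction, frame and unframe are hypothetical names, and the body is assumed to be already encrypted. unframe assumes the buffer holds at least one complete packet; the real read loop waits for more data instead.

```python
import hashlib
import hmac
import struct

HEADER_FMT = "!cBBBL"                     # 'P', flags, level, index, data_size
HEADER_LEN = struct.calcsize(HEADER_FMT)  # 8 bytes
MAC_BYTES = 12


class Direction(object):
    """Hypothetical per-direction state: MAC key plus implicit packet counter."""

    def __init__(self, mac_key, mac_bytes=MAC_BYTES):
        self.mac_key = mac_key
        self.mac_bytes = mac_bytes
        self.counter = 0  # never transmitted; both ends track it in lockstep

    def tag(self, header, body):
        mac = hmac.new(self.mac_key, digestmod=hashlib.sha256)
        mac.update(struct.pack("!Q", self.counter))
        mac.update(header)
        mac.update(body)
        self.counter += 1
        return mac.digest()[:self.mac_bytes]


def frame(sender, flags, level, index, body):
    # data_size in the header counts only the body, not header or MAC.
    header = struct.pack(HEADER_FMT, b"P", flags, level, index, len(body))
    return header + body + sender.tag(header, body)


def unframe(receiver, buf):
    """Return (body, remaining_buf); raise ValueError if the MAC is wrong."""
    header = buf[:HEADER_LEN]
    _, flags, level, index, data_size = struct.unpack(HEADER_FMT, header)
    total = HEADER_LEN + data_size + receiver.mac_bytes  # size on the wire
    body = buf[HEADER_LEN:HEADER_LEN + data_size]
    tag = buf[HEADER_LEN + data_size:total]
    if not hmac.compare_digest(receiver.tag(header, body), tag):
        raise ValueError("bad MAC")
    return body, buf[total:]


if __name__ == "__main__":
    key = b"m" * 32
    tx, rx = Direction(key), Direction(key)
    p0 = frame(tx, 0, 0, 0, b"hello")
    p1 = frame(tx, 0, 0, 0, b"world")
    body, rest = unframe(rx, p0 + p1)
    print(body)                    # b'hello'
    try:
        unframe(rx, p0)            # replay of packet 0 where packet 1 is due
    except ValueError as e:
        print("rejected:", e)      # rejected: bad MAC
```

The counter is never sent, so a replayed or reordered packet is checked against the wrong counter value and fails the MAC. This protection costs no extra bytes on the wire.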



More information about the shifter-users mailing list