[winswitch] Reworking Encryption in Xpra
Michael Vrable
mvrable at google.com
Wed Oct 31 05:50:56 GMT 2012
On Tue, Oct 30, 2012 at 10:46:30PM -0700, Michael Vrable wrote:
> Attached is a first patch (still needs to be tested) at adding better
> transport-layer encryption to Xpra--it adds message authentication to each of
> the packets to prevent any tampering of the data stream. Please don't commit
> it, as it isn't ready for that yet.
Does the mailing list strip attachments? I'm not sure it went through, so here
it is again inline.
--Michael Vrable
commit 1f7e729bbb1e641506ca36ee82bd78c6ff009595
Author: Michael Vrable <mvrable at google.com>
Proof-of-concept implementation of a more secure transport layer for Xpra.
This uses AES encryption of data packets (in CTR mode to avoid the need
for padding), and a truncated HMAC-SHA-256 to provide authentication of
the data stream.
This assumes that both sides have run some type of key-agreement
protocol to establish a shared session secret. I'm working on the key
exchange part in a separate patch which will follow.
This code isn't yet tested, but should give a basic idea.
diff --git a/src/xpra/protocol.py b/src/xpra/protocol.py
index 2fdec82..95da1ad 100644
--- a/src/xpra/protocol.py
+++ b/src/xpra/protocol.py
@@ -11,12 +11,16 @@
from wimpiggy.gobject_compat import import_gobject
gobject = import_gobject()
gobject.threads_init()
+import hashlib
+import hmac
import sys
import socket # for socket.error
import zlib
import struct
import time
import os
+from Crypto.Cipher import AES
+from Crypto.Util import Counter
NOYIELD = os.environ.get("XPRA_YIELD") is None
@@ -69,6 +73,168 @@ def zlib_compress(datatype, data, level=5):
return ZLibCompressed(datatype, cdata, level)
+class CryptoError(ValueError):
+    """Raised when a packet fails to decrypt or authenticate.
+
+    Being a ValueError subclass, it is also caught by ``except ValueError``.
+    """
+
+
+class TransportCrypto:
+    """Abstract base for transport-level encryption/authentication layers.
+
+    A concrete layer is normally installed once both peers have
+    authenticated each other and agreed on a shared secret, which keys its
+    cipher and MAC.  Every method here raises NotImplementedError; subclass
+    it (see NullCrypto and AESCrypto) instead of instantiating it.
+    """
+
+    def name(self):
+        """Return a short identifier for this crypto layer."""
+        raise NotImplementedError
+
+    def overhead_bytes(self, packet_len):
+        """Return how many bytes encrypt() adds to a packet_len-byte payload.
+
+        Lets a layer account for an IV, padding or a MAC.
+        """
+        raise NotImplementedError
+
+    def encrypt(self, header, payload):
+        """Encrypt and/or authenticate payload, returning the wire bytes.
+
+        The result is len(payload) + overhead_bytes(len(payload)) bytes long
+        and excludes header, although a MAC may cover the header bytes.
+        """
+        raise NotImplementedError
+
+    def decrypt(self, header, payload):
+        """Return the plaintext for payload.
+
+        Raises CryptoError on any failure, e.g. a bad MAC or invalid padding.
+        """
+        raise NotImplementedError
+
+
+class NullCrypto(TransportCrypto):
+    """Pass-through layer: no encryption, no authentication, no overhead."""
+
+    def name(self):
+        """Return the identifier "null"."""
+        return "null"
+
+    def overhead_bytes(self, packet_len):
+        """No bytes are ever added, whatever the packet length."""
+        return 0
+
+    def encrypt(self, header, payload):
+        """Return payload unchanged; header is ignored."""
+        return payload
+
+    def decrypt(self, header, payload):
+        """Return payload unchanged; never raises CryptoError."""
+        return payload
+
+
+class AESCrypto(TransportCrypto):
+    """A crypto layer using AES256 in CTR mode and HMAC-SHA-256.
+
+    For each communication direction, two keys are derived from the shared
+    session secret: one encryption key and one message authentication key.
+    Data payloads are encrypted with AES in CTR mode using a 128-bit counter
+    block whose initial value is all zeroes.  No padding is needed.
+
+    A message authentication code (MAC) is appended to each packet; the MAC is
+    computed using HMAC-SHA-256 (keyed with the authentication key).  The
+    digest is computed over the concatenation of a 64-bit, big-endian packet
+    counter (to prevent packet replay/reorder attacks), the unencrypted packet
+    header, and the encrypted packet data.  The hash is truncated to mac_bytes
+    in length (default: 12 bytes = 96 bits) then appended after the packet
+    data.  The length field encoded in the packet header is the length of the
+    data payload, not including the header and MAC.
+
+    An instance holds per-direction state (the CTR keystream position and the
+    packet counter), so every packet must be passed through encrypt() or
+    decrypt() exactly once, in transmission order.
+    """
+
+    def __init__(self, session_secret, context, mac_bytes=12):
+        """Initializes the crypto layer for one direction of a transport.
+
+        Args:
+            session_secret: A secret value negotiated by both sides of the
+                connection. The value of session_secret must never be re-used
+                in a different connection, or security may suffer. This may be
+                any string value.
+            context: A string used to distinguish multiple related
+                instantiations of AESCrypto. For example, a client and server
+                may compute a single shared session_secret, and use a context
+                of "server" for data sent by the server and "client" for data
+                sent by the client. This ensures that separate keys are used
+                for each data direction.
+            mac_bytes: Number of bytes (1 to 32) of the truncated HMAC-SHA-256
+                appended to each packet. Smaller values reduce overhead but
+                make forged or corrupted packets more likely to go undetected.
+
+        Raises:
+            ValueError: if mac_bytes is outside 1..32. Zero would make every
+                packet fail authentication, and values above 32 would make
+                overhead_bytes() disagree with the real MAC length.
+        """
+        digest_size = hashlib.sha256().digest_size
+        if not 1 <= mac_bytes <= digest_size:
+            raise ValueError("mac_bytes must be between 1 and %d, got %r"
+                             % (digest_size, mac_bytes))
+        self._context = context
+
+        # Derive two independent 256-bit keys from the session secret: one
+        # for AES encryption and one for HMAC authentication. The secret
+        # itself is not kept once the keys have been derived.
+        def derive_key(subtype):
+            return hmac.new(session_secret, "%s-%s" % (context, subtype),
+                            digestmod=hashlib.sha256).digest()
+        key_enc = derive_key("aes")
+        self._hmac_key = derive_key("mac")
+        self._mac_bytes = mac_bytes
+
+        # CTR mode counting from zero is safe as long as every session (and
+        # direction) uses a unique encryption key. Counter.new defaults to
+        # initial_value=1, so zero is requested explicitly to match the
+        # design described in the class docstring.
+        self._cipher = AES.new(key_enc, mode=AES.MODE_CTR,
+                               counter=Counter.new(128, initial_value=0))
+
+        # The packet counter is covered by each packet's MAC (to prevent
+        # replay/reordering attacks) but is never transmitted, to reduce
+        # overhead. It is advanced by _next_mac().
+        self._packet_counter = 0
+
+    def name(self):
+        return "aes256-ctr/hmac256-%d" % (self._mac_bytes * 8)
+
+    def overhead_bytes(self, packet_len):
+        # No padding is needed for CTR mode, so the only overhead is the MAC,
+        # independent of packet size.
+        return self._mac_bytes
+
+    def _next_mac(self, header, ciphertext):
+        """Return the truncated MAC for the next packet and bump the counter."""
+        mac = hmac.new(self._hmac_key, digestmod=hashlib.sha256)
+        mac.update(struct.pack("!Q", self._packet_counter))
+        self._packet_counter += 1
+        mac.update(header)
+        mac.update(ciphertext)
+        return mac.digest()[:self._mac_bytes]
+
+    @staticmethod
+    def _macs_equal(expected, received):
+        """Compare two MAC values in time independent of where they differ.
+
+        A plain != may stop at the first mismatching byte, which can let an
+        attacker recover a valid MAC byte by byte from response timing.
+        """
+        compare_digest = getattr(hmac, "compare_digest", None)
+        if compare_digest is not None:
+            # Provided by Python 2.7.7+ and 3.3+.
+            return compare_digest(expected, received)
+        if len(expected) != len(received):
+            return False
+        diff = 0
+        for x, y in zip(bytearray(expected), bytearray(received)):
+            diff |= x ^ y
+        return diff == 0
+
+    def encrypt(self, header, payload):
+        # Encrypt-then-MAC: the MAC covers the ciphertext, not the plaintext.
+        ciphertext = self._cipher.encrypt(payload)
+        return ciphertext + self._next_mac(header, ciphertext)
+
+    def decrypt(self, header, payload):
+        if len(payload) < self._mac_bytes:
+            raise CryptoError("Bad decryption")
+        ciphertext = payload[:-self._mac_bytes]
+        received_mac = payload[-self._mac_bytes:]
+        # Authenticate before decrypting. On failure the packet counter has
+        # advanced but the CTR keystream has not, so this instance no longer
+        # matches the sender's state and should not be reused.
+        if not self._macs_equal(self._next_mac(header, ciphertext),
+                                received_mac):
+            raise CryptoError("Bad decryption")
+        return self._cipher.decrypt(ciphertext)
+
+
class Protocol(object):
CONNECTION_LOST = "connection-lost"
GIBBERISH = "gibberish"
@@ -97,12 +263,8 @@ class Protocol(object):
self._encoder = self.bencode
self._decompressor = zlib.decompressobj()
self._compression_level = 0
- self.cipher_in = None
- self.cipher_in_name = None
- self.cipher_in_block_size = 0
- self.cipher_out = None
- self.cipher_out_name = None
- self.cipher_out_block_size = 0
+ self.cipher_in = NullCrypto()
+ self.cipher_out = NullCrypto()
def make_daemon_thread(target, name):
daemon_thread = Thread(target=target, name=name)
daemon_thread.setDaemon(True)
@@ -112,33 +274,15 @@ class Protocol(object):
self._read_thread = make_daemon_thread(self._read_thread_loop, "read_loop")
self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "read_parse_loop")
- def get_cipher(self, ciphername, iv, password, key_salt, iterations):
- log("get_cipher_in(%s, %s, %s, %s, %s)", ciphername, iv, password, key_salt, iterations)
- if not ciphername:
- return None, 0
- assert iterations>=100
- assert ciphername=="AES"
- assert password and iv
- from Crypto.Cipher import AES
- from Crypto.Protocol.KDF import PBKDF2
- #stretch the password:
- block_size = 32 #fixme: can we derive this?
- secret = PBKDF2(password, key_salt, dkLen=block_size, count=iterations)
- #secret = (password+password+password+password+password+password+password+password)[:32]
- log("get_cipher(%s, %s, %s) secret=%s, block_size=%s", ciphername, iv, password, secret.encode('hex'), block_size)
- return AES.new(secret, AES.MODE_CBC, iv), block_size
-
- def set_cipher_in(self, ciphername, iv, password, key_salt, iterations):
- if self.cipher_in_name!=ciphername:
- log.info("receiving data using %s encryption", ciphername)
- self.cipher_in_name = ciphername
- self.cipher_in, self.cipher_in_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations)
-
- def set_cipher_out(self, ciphername, iv, password, key_salt, iterations):
- if self.cipher_out_name!=ciphername:
- log.info("sending data using %s encryption", ciphername)
- self.cipher_out_name = ciphername
- self.cipher_out, self.cipher_out_block_size = self.get_cipher(ciphername, iv, password, key_salt, iterations)
+    def set_cipher(self, direction, ciphername, session_secret, context):
+        """Install an AESCrypto layer for one direction of this connection.
+
+        Args:
+            direction: "in" to protect received packets, "out" for sent ones.
+            ciphername: currently unused; AESCrypto is always selected.
+            session_secret: shared secret passed through to AESCrypto.
+            context: per-direction label passed through to AESCrypto so that
+                each direction is keyed independently.
+
+        Raises:
+            ValueError: if direction is neither "in" nor "out".
+        """
+        # Validate first, so a bad call neither builds a cipher nor logs that
+        # encryption was changed.
+        if direction not in ("in", "out"):
+            raise ValueError("Unknown cipher direction: %r" % (direction,))
+        # NOTE(review): ciphername is ignored; confirm callers do not expect
+        # it to select a different algorithm.
+        cipher = AESCrypto(session_secret, context)
+        log.info("setting %s encryption to %s (context: %s)",
+                 direction, cipher.name(), context)
+        if direction == "in":
+            self.cipher_in = cipher
+        else:
+            self.cipher_out = cipher
def __str__(self):
ti = ["%s:%s" % (x.name, x.is_alive()) for x in self.get_threads()]
@@ -289,29 +433,12 @@ class Protocol(object):
#fire the end_send callback when the last packet (index==0) makes it out:
if index==0:
ecb = end_send_cb
- if self.cipher_out:
- proto_flags |= Protocol.FLAGS_CIPHER
- #note: since we are padding: l!=len(data)
- padding = (self.cipher_out_block_size - len(data) % self.cipher_out_block_size) * " "
- if len(padding)==0:
- padded = data
- else:
- padded = data+padding
- actual_size = payload_size + len(padding)
- assert len(padded)==actual_size
- data = self.cipher_out.encrypt(padded)
- assert len(data)==actual_size
- log("sending %s bytes encrypted with %s padding", payload_size, len(padding))
- if actual_size<16384:
- #'p' + protocol-flags + compression_level + packet_index + data_size
- if type(data)==unicode:
- data = str(data)
- header_and_data = struct.pack('!BBBBL%ss' % actual_size, ord("P"), proto_flags, level, index, payload_size, data)
- self._write_queue.put((header_and_data, scb, ecb))
- else:
- header = struct.pack('!BBBBL', ord("P"), proto_flags, level, index, payload_size)
- self._write_queue.put((header, scb, None))
- self._write_queue.put((data, None, ecb))
+ header = struct.pack('!BBBBL',
+ ord("P"), proto_flags, level,
+ index, payload_size)
+ data = self.cipher_out.encrypt(header, data)
+ self._write_queue.put((header, scb, None))
+ self._write_queue.put((data, None, ecb))
counter += 1
finally:
self.output_packetcount += 1
@@ -437,20 +564,14 @@ class Protocol(object):
break #packet still too small
#packet format: struct.pack('cBBBL', ...) - 8 bytes
try:
- _, protocol_flags, compression_level, packet_index, data_size = struct.unpack_from('!cBBBL', read_buffer)
+ header = read_buffer[:8]
+ _, protocol_flags, compression_level, packet_index, data_size = struct.unpack('!cBBBL', header)
except Exception, e:
raise Exception("invalid packet header: %s" % list(read_buffer[:8]), e)
read_buffer = read_buffer[8:]
bl = len(read_buffer)
- if protocol_flags & Protocol.FLAGS_CIPHER:
- assert self.cipher_in_block_size>0, "received cipher block but we don't have a cipher do decrypt it with"
- padding = (self.cipher_in_block_size - data_size % self.cipher_in_block_size) * " "
- payload_size = data_size + len(padding)
- else:
- #no cipher, no padding:
- padding = None
- payload_size = data_size
- assert payload_size>0
+ payload_size = (data_size +
+ self.cipher_in.overhead_bytes(data_size))
if payload_size>self.max_packet_size:
#this packet is seemingly too big, but check again from the main UI thread
@@ -474,13 +595,11 @@ class Protocol(object):
raw_string = read_buffer[:payload_size]
read_buffer = read_buffer[payload_size:]
#decrypt if needed:
- data = raw_string
- if self.cipher_in and protocol_flags & Protocol.FLAGS_CIPHER:
- log("received %s encrypted bytes with %s padding", payload_size, len(padding))
- data = self.cipher_in.decrypt(raw_string)
- if padding:
- assert data.endswith(padding), "decryption failed: string does not end with '%s': %s (%s) -> %s (%s)" % (padding, list(bytearray(raw_string)), type(raw_string), list(bytearray(data)), type(data))
- data = data[:-len(padding)]
+ #decrypt and authenticate the payload; error messages use raw_string
+ #because data may be unbound (or hold an earlier value) if decrypt() raises:
+ try:
+     data = self.cipher_in.decrypt(header, raw_string)
+ except CryptoError, e:
+     return self._call_connection_lost("Decryption failed: %s: %s" % (e, repr_ellipsized(raw_string)))
+ #the header's length field counts plaintext bytes only:
+ if len(data) != data_size:
+     return self._call_connection_lost("Decryption failed: expected %s bytes but got %s" % (data_size, len(data)))
#uncompress if needed:
if compression_level>0:
if self.chunked_compression:
@@ -490,9 +609,6 @@ class Protocol(object):
if sys.version>='3':
data = data.decode("latin1")
- if self.cipher_in and not (protocol_flags & Protocol.FLAGS_CIPHER):
- return self._call_connection_lost("unencrypted packet dropped: %s" % repr_ellipsized(data))
-
if self._closed:
return
if packet_index>0:
More information about the shifter-users
mailing list