# Written by Bram Cohen
# Modified by Cameron Dale
# see LICENSE.txt for license information
#
# $Id: NatCheck.py 279 2007-08-22 20:09:43Z camrdale-guest $

"""Check if a peer is unreachable behind a NAT.

@type logger: C{logging.Logger}
@var logger: the logger to send all log messages to for this module
@type CHECK_PEER_ID_ENCRYPTED: C{boolean}
@var CHECK_PEER_ID_ENCRYPTED: whether to check if connecting peers are encrypted

"""

from cStringIO import StringIO
from socket import error as socketerror
from DebTorrent.BTcrypto import Crypto, CRYPTO_OK
from DebTorrent.__init__ import protocol_name
from binascii import b2a_hex
import struct
import logging

# Module-level logger; every log message emitted by this module goes through it.
logger = logging.getLogger('DebTorrent.BT1.NatCheck')

# When True (and CRYPTO_OK reports that crypto support is available), an
# encrypted NAT check carries on through the full encrypted handshake and
# verifies the peer ID.  Otherwise the check is answered as successful as soon
# as the verification pattern is found in the peer's reply.
CHECK_PEER_ID_ENCRYPTED = True

# Order of the fields in the peer handshake:
# header, reserved, download id, my id, [length, message]

class NatCheck:
    """Check if a peer is unreachable behind a NAT.
    
    An outgoing connection is made to the peer and the handshake it sends
    back is verified (protocol header, info hash and peer ID). The result
    is reported exactly once through L{resultfunc}.
    
    Incoming data is processed by a small state machine: L{next_func} is 
    called with the next L{next_len} bytes of the buffer and returns the 
    following C{(length, method)} pair, or None to abort the check.
    
    @type resultfunc: C{method}
    @ivar resultfunc: the method to call with the result when complete
    @type downloadid: C{string}
    @ivar downloadid: the info hash of the torrent to use
    @type peerid: C{string}
    @ivar peerid: the peer ID of the peer being checked
    @type readable_id: C{string}
    @ivar readable_id: the hex encoded peer ID, used in log messages
    @type ip: C{string}
    @ivar ip: the IP of the peer being checked
    @type port: C{int}
    @ivar port: the port to connect to the peer on
    @type encrypted: C{boolean}
    @ivar encrypted: whether to use an encrypted connection
    @type closed: C{boolean}
    @ivar closed: whether the connection has been closed
    @type buffer: C{string}
    @ivar buffer: the buffer of received data from the connection
    @type read: C{method}
    @ivar read: the method to use to read from the connection
    @type write: C{method}
    @ivar write: the method to use to write to the connection
    @type connection: L{DebTorrent.SocketHandler.SingleSocket}
    @ivar connection: the connection to the peer
    @type _dc: C{boolean}
    @ivar _dc: whether encrypted connections have been disabled
        (only set when an encrypted connection was requested)
    @type encrypter: L{DebTorrent.BTcrypto.Crypto}
    @ivar encrypter: the encrypter to use for the connection
    @type next_len: C{int}
    @ivar next_len: the next amount of data to read from the connection
    @type next_func: C{method}
    @ivar next_func: the next method to use to process incoming data on the 
        connection
    @type _max_search: C{int}
    @ivar _max_search: the number of remaining bytes to search for the pattern
    @type cryptmode: C{int}
    @ivar cryptmode: the type of encryption being used
    
    """
    
    def __init__(self, resultfunc, downloadid, peerid, ip, port, rawserver,
                 encrypted = False):
        """Initialize the instance and start a connection to the peer.
        
        If the connection can not be started, the failure is reported 
        immediately through L{answer}.
        
        @type resultfunc: C{method}
        @param resultfunc: the method to call with the result when complete
        @type downloadid: C{string}
        @param downloadid: the info hash of the torrent to use
        @type peerid: C{string}
        @param peerid: the peer ID of the peer being checked
        @type ip: C{string}
        @param ip: the IP of the peer being checked
        @type port: C{int}
        @param port: the port to connect to the peer on
        @type rawserver: L{DebTorrent.RawServer.RawServer}
        @param rawserver: the server instance to use
        @type encrypted: C{boolean}
        @param encrypted: whether to use an encrypted connection
            (optional, defaults to False)
        
        """
        
        self.resultfunc = resultfunc
        self.downloadid = downloadid
        self.peerid = peerid
        self.readable_id = b2a_hex(peerid)
        logger.info('Starting a NAT check for '+self.readable_id)
        self.ip = ip
        self.port = port
        self.encrypted = encrypted
        self.closed = False
        self.buffer = ''
        self.read = self._read
        self.write = self._write
        # Prepare the parser state before any I/O is started, so the instance
        # is fully usable by the time the first connection callback arrives.
        self.next_len, self.next_func = 1+len(protocol_name), self.read_header
        try:
            self.connection = rawserver.start_connection((ip, port), self)
            if encrypted:
                logger.info('Initiating an encrypted connection to '+self.readable_id)
                self._dc = not(CRYPTO_OK and CHECK_PEER_ID_ENCRYPTED)
                self.encrypter = Crypto(True, disable_crypto = self._dc)
                self.write(self.encrypter.pubkey+self.encrypter.padding())
            else:
                logger.info('Initiating an unencrypted connection to '+self.readable_id)
                self.encrypter = None
                # header, 8 reserved bytes, info hash
                self.write(chr(len(protocol_name)) + protocol_name +
                    (chr(0) * 8) + downloadid)
        except (socketerror, IOError):
            logger.exception('Could not initiate a connection to %s:%d for %s', 
                             ip, port, self.readable_id)
            self.answer(False)

    def answer(self, result):
        """Close the connection and return the result.
        
        @type result: C{boolean}
        @param result: whether the peer is connectable
        
        """
        
        self.closed = True
        try:
            self.connection.close()
        except AttributeError:
            # The connection attribute is missing when starting the
            # connection failed in the constructor; nothing to close then.
            pass
        logger.info('Result was '+str(result)+' for: '+self.readable_id)
        self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port)

    def _read_header(self, s):
        """Read the protocol header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if s == chr(len(protocol_name))+protocol_name:
            return 8, self.read_options
        return None

    def read_header(self, s):
        """Read the possibly encrypted protocol header.
        
        A plain protocol header is only accepted on an unencrypted check.
        On an encrypted check anything else is assumed to be the start of
        the peer's key, so the data is put back on the buffer.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if self._read_header(s):
            if self.encrypted:
                logger.info('Dropped the connection as it was unencrypted: '+self.readable_id)
                return None
            return 8, self.read_options
        if not self.encrypted:
            logger.info('Got a bad protocol name: '+self.readable_id)
            return None
        # Not a plain header: re-buffer it so it is read again as key data.
        self._write_buffer(s)
        return self.encrypter.keylength, self.read_crypto_header

    ################## ENCRYPTION SUPPORT ######################

    def _start_crypto(self):
        """Setup the connection for encrypted communication."""
        self.encrypter.setrawaccess(self._read,self._write)
        self.write = self.encrypter.write
        self.read = self.encrypter.read
        # Anything already buffered arrived before the switch, so it has
        # to be decrypted here.
        if self.buffer:
            self.buffer = self.encrypter.decrypt(self.buffer)

    def _end_crypto(self):
        """Return the connection to unencrypted communication.
        
        This is the counterpart of L{_start_crypto}: reading and writing go
        directly to the raw connection again and the encrypter is dropped.
        
        """
        
        self.read = self._read
        self.write = self._write
        self.encrypter = None

    def read_crypto_header(self, s):
        """Read the encryption key.
        
        Sends back the third block of the encrypted handshake, offering 
        only full stream encryption.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        self.encrypter.received_key(s)
        self.encrypter.set_skey(self.downloadid)
        cryptmode = '\x00\x00\x00\x02'    # full stream encryption
        padc = self.encrypter.padding()
        self.write( self.encrypter.block3a
                  + self.encrypter.block3b
                  + self.encrypter.encrypt(
                        ('\x00'*8)            # VC
                      + cryptmode             # acceptable crypto modes
                      + struct.pack('>h', len(padc))
                      + padc                  # PadC
                      + '\x00\x00' ) )        # no initial payload data
        # limit on how much incoming data is searched for the VC pattern
        self._max_search = 520
        return 1, self.read_crypto_block4a

    def _search_for_pattern(self, s, pat):
        """Search for a pattern in the encrypted protocol header.
        
        If the pattern is found, the data following it is put back on the
        buffer. If not, the end of the data (which could contain the start 
        of the pattern) is put back on the buffer so the search can continue
        when more data arrives. Once more than L{_max_search} bytes have been
        searched without success the check is answered as failed.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @type pat: C{string}
        @param pat: the pattern to find
        @rtype: C{boolean}
        @return: whether the pattern was found
        
        """
        
        p = s.find(pat)
        if p < 0:
            if len(s) >= len(pat):
                self._max_search -= len(s)+1-len(pat)
            if self._max_search < 0:
                # Searched too far without finding the pattern: give up and
                # report the failure (this also closes the connection).
                self.answer(False)
                return False
            self._write_buffer(s[1-len(pat):])
            return False
        self._write_buffer(s[p+len(pat):])
        return True

    ### OUTGOING CONNECTION ###

    def read_crypto_block4a(self, s):
        """Read the encrypted protocol header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if not self._search_for_pattern(s,self.encrypter.VC_pattern()):
            return -1, self.read_crypto_block4a     # wait for more data
        if self._dc:                        # can't or won't go any further
            self.answer(True)
            return None
        self._start_crypto()
        return 6, self.read_crypto_block4b

    def read_crypto_block4b(self, s):
        """Read the encrypted protocol mode.
        
        The 6 bytes hold the selected crypto mode (4 bytes) followed by
        the length of the padding (2 bytes), both big-endian.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        cryptmode, padlen = struct.unpack('>iH', s[:6])
        self.cryptmode = cryptmode % 4
        if self.cryptmode != 2:
            logger.info('Dropped the encrypted connection due to an unknown crypt mode: '+self.readable_id)
            return None                     # unknown encryption
        if padlen > 512:
            logger.info('Dropped the encrypted connection due to bad padding: '+self.readable_id)
            return None
        if padlen:
            return padlen, self.read_crypto_pad4
        return self.read_crypto_block4done()

    def read_crypto_pad4(self, s):
        """Read the encrypted protocol padding.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        # discard data
        return self.read_crypto_block4done()

    def read_crypto_block4done(self):
        """Finish with the encrypted header.
        
        Sends the regular handshake (protocol header, 8 reserved bytes and
        the info hash) over the now encrypted connection.
        
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        # NOTE: read_crypto_block4b currently only accepts mode 2, so this
        # branch is not reached; it is kept for handshake-only encryption.
        if self.cryptmode == 1:     # only handshake encryption
            if not self.buffer:  # oops; check for exceptions to this
                logger.info('Dropped the encrypted connection due to a lack of buffer: '+self.readable_id)
                return None
            self._end_crypto()
        # Same handshake as is sent on an unencrypted connection.
        self.write(chr(len(protocol_name)) + protocol_name + 
            (chr(0) * 8) + self.downloadid)
        return 1+len(protocol_name), self.read_encrypted_header

    ### START PROTOCOL OVER ENCRYPTED CONNECTION ###

    def read_encrypted_header(self, s):
        """Read the regular protocol name header from the encrypted stream.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        return self._read_header(s)

    ################################################

    def read_options(self, s):
        """Read the options from the header.
        
        The options are not examined.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Verify the torrent infohash from the header.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: C{int}, C{method}
        @return: the next amount of data to read and the method to call with
            it, or None if there is no next method to call
        
        """
        
        if s != self.downloadid:
            logger.warning('Torrent info hash does not match: '+self.readable_id)
            return None
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """Verify the peer's ID and return the answer.
        
        @type s: C{string}
        @param s: the incoming data from the connection
        @rtype: None
        @return: None
        
        """
        
        if s != self.peerid:
            logger.info('Peer ID does not match: '+self.readable_id)
            return None
        self.answer(True)
        return None

    def _write(self, message):
        """Write a raw message out on the connection.
        
        Nothing is written once the connection has been closed.
        
        @type message: C{string}
        @param message: the raw data to write to the connection
        
        """
        
        if not self.closed:
            self.connection.write(message)

    def data_came_in(self, connection, s):
        """Process the incoming data on the connection.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection the data came in on (not used)
        @type s: C{string}
        @param s: the incoming data from the connection
        
        """
        
        self.read(s)

    def _write_buffer(self, s):
        """Write data back onto the buffer.
        
        The data is put at the front of the buffer, so it is the first 
        to be read again.
        
        @type s: C{string}
        @param s: the data to rebuffer
        
        """
        
        self.buffer = s+self.buffer

    def _read(self, s):
        """Process the data that came in.
        
        Repeatedly hands the next L{next_len} bytes of the buffer to
        L{next_func} until there is not enough data left, the connection is 
        closed or the check is aborted. An exception or a None result from 
        the processing method is reported as a failed check.
        
        @type s: C{string}
        @param s: the (unencrypted) incoming data from the connection
        
        """
        
        self.buffer += s
        while True:
            if self.closed:
                return
            # self.next_len = # of characters function expects
            # or 0 = all characters in the buffer
            # or -1 = wait for next read, then all characters in the buffer
            # not compatible w/ keepalives, switch out after all negotiation complete
            if self.next_len <= 0:
                m = self.buffer
                self.buffer = ''
            elif len(self.buffer) >= self.next_len:
                m = self.buffer[:self.next_len]
                self.buffer = self.buffer[self.next_len:]
            else:
                return
            try:
                x = self.next_func(m)
            except Exception:
                # Any error while processing peer data fails the check, but
                # interpreter exits and keyboard interrupts still propagate.
                logger.exception('Dropped connection due to exception: '+self.readable_id)
                if not self.closed:
                    self.answer(False)
                return
            if x is None:
                # The check was aborted, or has already been answered.
                if not self.closed:
                    self.answer(False)
                return
            self.next_len, self.next_func = x
            if self.next_len < 0:  # already checked buffer
                return             # wait for additional data

    def connection_lost(self, connection):
        """Close the connection and return the failure.
        
        Nothing is reported if the check was already answered.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was lost (not used)
        
        """
        
        if not self.closed:
            logger.warning('Connection was dropped externally: '+self.readable_id)
            self.closed = True
            self.resultfunc(False, self.downloadid, self.peerid, self.ip, self.port)

    def connection_flushed(self, connection):
        """Do nothing.
        
        @type connection: L{DebTorrent.SocketHandler.SingleSocket}
        @param connection: the connection that was flushed (not used)
        
        """
        
        pass
