# -*- coding: utf-8 -*-
# Elisa - Home multimedia server
# Copyright (C) 2006-2008 Fluendo Embedded S.L. (www.fluendo.com).
# All rights reserved.
#
# This file is available under one of two license agreements.
#
# This file is licensed under the GPL version 2.
# See "LICENSE.GPL" in the root of this distribution including a special
# exception to use Elisa with Fluendo's plugins.
#
# The GPL part of Elisa is also available under a commercial licensing
# agreement from Fluendo.
# See "LICENSE.Elisa" in the root directory of this distribution package
# for details on that license.
#
# Authors: Alessandro Decina <alessandro@fluendo.com>

from twisted.trial.unittest import TestCase, SkipTest
from twisted.internet import defer
from twisted.protocols.amp import AMP, Command, String, Integer, Float
from twisted.test.iosim import connectedServerAndClient

from elisa.plugins.amp.master import Master, SlaveProcessProtocol, StartError
from elisa.plugins.amp.protocol import Union, Ping, \
        MasterFactory, MasterProtocol, SlaveFactory, SlaveProtocol

import platform

# useful when debugging the protocol
#import sys
#from twisted.python import log
#log.startLogging(sys.stdout, setStdout=0)

class StubSlaveProcessProtocol(SlaveProcessProtocol):
    """
    Slave process protocol that records every process termination on its
    master and fires C{master.all_slaves_dead} once C{master._spawned} is
    empty after a termination has been handled.
    """

    def __init__(self, *args):
        SlaveProcessProtocol.__init__(self, *args)

    def processEnded(self, reason):
        self.master.dead_processes.append((self, reason))
        SlaveProcessProtocol.processEnded(self, reason)

        master = self.master
        if master._spawned:
            # some slaves are still around, nothing to signal yet
            return
        master.all_slaves_dead.callback(self)

class TestMasterProtocol(MasterProtocol):
    """
    Master protocol with small ping_period/ping_timeout values that stops
    answering Ping requests once its master has been flagged as dead.
    """
    ping_timeout = 1
    ping_period = 2

    def ping(self):
        if not self.factory.master.dead:
            return MasterProtocol.ping(self)

        # a Deferred nobody ever fires: the peer never gets an answer
        return defer.Deferred()
    Ping.responder(ping)

class TestMasterFactory(MasterFactory):
    """
    Master factory building L{TestMasterProtocol} instances.
    """
    protocol = TestMasterProtocol

class TestMaster(Master):
    """
    Master wired to the test server factory and slave process protocol.

    It keeps a list of ended slave processes and a Deferred fired when no
    spawned slave is left.
    """
    slaveProcessProtocolFactory = StubSlaveProcessProtocol
    serverFactory = TestMasterFactory

    def __init__(self, address=None, slave_runner=None):
        Master.__init__(self, address, slave_runner)
        # appended to and fired by StubSlaveProcessProtocol.processEnded
        self.all_slaves_dead = defer.Deferred()
        self.dead_processes = []
        # when True, TestMasterProtocol stops answering pings
        self.dead = False
        self.timeout = 0

class TestSlaveProtocol(SlaveProtocol):
    """
    Slave protocol that answers C{factory.timeout} pings and then stops
    answering; a timeout of 0 means it always answers.
    """

    def ping(self):
        answered = getattr(self, 'pings', 0)
        limit = self.factory.timeout
        if not (limit and answered == limit):
            self.pings = answered + 1
            return SlaveProtocol.ping(self)

        # limit reached: return a Deferred that never fires
        return defer.Deferred()
    Ping.responder(ping)

class TestSlaveFactory(SlaveFactory):
    """
    Slave factory building L{TestSlaveProtocol} instances.

    @param cookie:  cookie forwarded to L{SlaveFactory}.
    @param timeout: number of pings the slave answers before it goes silent
                    (0 means never silent).
    """
    protocol = TestSlaveProtocol

    def __init__(self, cookie, timeout):
        SlaveFactory.__init__(self, cookie)
        self.timeout = timeout

def dying_runner(cookie, connection_string):
    """
    Runner that ignores its arguments and immediately returns 1, without
    ever contacting the master.
    """
    exit_status = 1
    return exit_status

def blocked_runner(cookie, connection_string):
    """
    Runner that never connects to the master, so it never answers any ping
    (not even the first one); it just runs the reactor.
    """
    from twisted.internet import reactor as event_loop

    event_loop.run()

def okayish_runner(cookie, connection_string, disconnect=False, timeout=0):
    """
    Runner that connects a L{TestSlaveFactory} to the master and runs the
    reactor.

    @param cookie:            cookie handed to the slave factory.
    @param connection_string: C{'tcp:HOST:PORT'} or C{'unix:PATH'}.
    @param disconnect:        if true, drop the connection 2 seconds after
                              the reactor starts.
    @param timeout:           number of pings the slave answers before it
                              stops answering; 0 means it always answers.
    @raise ValueError:        if the scheme is neither C{tcp} nor C{unix}.
    """
    from twisted.internet import reactor

    scheme, _, location = connection_string.partition(':')
    factory = TestSlaveFactory(cookie, timeout)

    if scheme == 'tcp':
        # split on the last colon so hosts containing ':' still parse
        host, port = location.rsplit(':', 1)
        connector = reactor.connectTCP(host, int(port), factory)
    elif scheme == 'unix':
        # as before, only the first ':'-separated field is the socket path
        address = location.split(':', 1)[0]
        connector = reactor.connectUNIX(address, factory)
    else:
        # explicit error rather than an assert, which python -O strips
        raise ValueError('unsupported connection string: %r'
                         % (connection_string,))

    if disconnect:
        reactor.callLater(2, connector.disconnect)

    reactor.run()

def disconnecting_runner(cookie, connection_string):
    """
    Runner that connects to the master like okayish_runner, then drops the
    connection after a while.
    """
    okayish_runner(cookie, connection_string, disconnect=True)

def timeout_runner(cookie, connection_string):
    """
    Runner that stays connected to the master but stops answering pings
    after the third one.
    """
    ping_limit = 3
    okayish_runner(cookie, connection_string, False, ping_limit)

class MasterMixin(object):
    """
    Tests shared by every Master transport.

    Concrete subclasses also inherit from TestCase and must set up a started
    TestMaster as C{self.master} in setUp().
    """

    def tearDown(self):
        return self.master.stop()

    def _resetRunner(self, runner):
        self.master._slave_runner = runner

    def setRunner(self, runner):
        """
        Make the master spawn slaves with the module-level function named
        C{runner}, restoring the previous runner at cleanup time.
        """
        current = self.master._slave_runner
        self.master._slave_runner = '%s.%s' % (__name__, runner)
        self.addCleanup(self._resetRunner, current)

    def _checkAllSlavesDie(self, runner):
        """
        Start two slaves using C{runner} and wait for all of them to die on
        their own. Check that both deaths were recorded, then stop the slaves
        and the master.
        """
        def slavesDeadCb(result):
            self.assertEqual(len(self.master.dead_processes), 2)
            dfr = self.master.stopSlaves()
            dfr.addCallback(lambda result: self.master.stop())
            return dfr

        def slavesStartedCb(result):
            dfr = self.master.all_slaves_dead
            dfr.addCallback(slavesDeadCb)
            return dfr

        self.setRunner(runner)
        dfr = self.master.startSlaves(2)
        dfr.addCallback(slavesStartedCb)

        return dfr

    def testStartSlavesFail(self):
        """
        Call startSlaves() and make all the slaves fail to start. Check that
        startSlaves() errbacks when all the processes are ended.
        """
        def startSlavesCb(result):
            self.assertEqual(len(self.master.dead_processes), 2)
            return self.master.stopSlaves()

        self.setRunner('dying_runner')
        dfr = self.master.startSlaves(2)
        self.assertFailure(dfr, StartError)
        dfr.addCallback(startSlavesCb)

        return dfr

    def testStartStopSlaves(self):
        """
        Start and stop slaves.
        Check that the slaves are started correctly and that they die when
        master.stopSlaves() is called.
        """
        def slavesStoppedCb(result):
            self.assertEqual(len(self.master.dead_processes), 2)
            return self.master.stop()

        def slavesStartedCb(result):
            dfr = self.master.stopSlaves()
            dfr.addCallback(slavesStoppedCb)

            return dfr

        self.setRunner('okayish_runner')
        dfr = self.master.startSlaves(2)
        dfr.addCallback(slavesStartedCb)

        return dfr

    def testStartSlavesSpawnTimeout(self):
        """
        Start slaves that don't connect to the master, resulting in a timeout.
        """
        def startSlavesCb(result):
            # without this, an unexpected success would make the test pass
            self.fail("startSlaves() succeeded although no slave connected")

        def startSlavesEb(failure):
            self.assertEqual(len(self.master.dead_processes), 2)
            for process, reason in self.master.dead_processes:
                # slaves that never connect are expected to be killed with
                # signal 9 (SIGKILL)
                self.assertEqual(reason.value.signal, 9)

            return self.master.stop()

        self.setRunner('blocked_runner')
        dfr = self.master.startSlaves(2)
        dfr.addCallbacks(startSlavesCb, startSlavesEb)

        return dfr

    def testSlavesDisconnect(self):
        """
        Start slaves that disconnect from the master after a while.
        """
        return self._checkAllSlavesDie('disconnecting_runner')

    def testSlavesPingTimeout(self):
        """
        Run a slave that stays connected but stops answering ping requests after
        a while.
        """
        return self._checkAllSlavesDie('timeout_runner')

    def testMasterPingTimeout(self):
        """
        check that a slave kills itself if the master does not answer to pings
        """
        def slavesStoppedCb(result):
            self.assertEqual(len(self.master.dead_processes), 2)
            return self.master.stop()

        def slavesStartedCb(result):
            dfr = self.master.all_slaves_dead
            dfr.addCallback(slavesStoppedCb)

            return dfr

        self.setRunner('okayish_runner')
        # makes TestMasterProtocol stop answering the slaves' pings
        self.master.dead = True
        dfr = self.master.startSlaves(2)
        dfr.addCallback(slavesStartedCb)

        return dfr

class UnixMasterTestCase(MasterMixin, TestCase):
    """
    Run the shared Master tests over unix sockets (Linux only).
    """
    def setUp(self):
        if platform.system() == 'Linux':
            self.master = TestMaster(address='unix:')
            self.master.start()
            return

        raise SkipTest("This is only supported in Linux")


class TCPMasterTestCase(MasterMixin, TestCase):
    """
    Run the shared Master tests over TCP.
    """
    def setUp(self):
        master = TestMaster(address='tcp:')
        self.master = master
        master.start()
