#! /usr/bin/env python
#
# ufw: front-end for Linux firewalling
#
# Copyright (C) 2008 Canonical Ltd.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License version 3,
#    as published by the Free Software Foundation.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import re
import shutil
import socket
import stat
from stat import *
import subprocess
import sys
from tempfile import mkstemp

version = "#VERSION#"
programName = "ufw"

# Refuse to run on interpreters older than python 2.5.  Comparing the
# (major, minor) pair as a tuple covers both parts in a single test.
if tuple(sys.version_info[:2]) < (2, 5):
    print >> sys.stderr, programName + ": Need at least python 2.5\n"
    sys.exit(1)

# Default locations of the ufw configuration files.  The '#..._PREFIX#'
# markers are presumably substituted at install time -- TODO confirm.
files = {
    'defaults': '#CONFIG_PREFIX#/default/ufw',
    'conf': '#CONFIG_PREFIX#/ufw/ufw.conf',
}

# Module-wide switches: 'debugging' enables debug() output and
# 'disable_checks' skips UFWBackend._do_checks().
debugging = False
disable_checks = False


def process_args():
    '''Process command line arguments.

    Parses sys.argv, which is modified in place: a leading "--dry-run" and
    then a leading "delete" (both matched case-insensitively) are removed.

    Returns a tuple (action, rule, type, dryrun):
      action -- the lowercased command (one of allowed_cmds), with
                "logging" mapped to "logging-on"/"logging-off" and
                "default" mapped to "default-allow"/"default-deny"
      rule   -- a UFWRule for "allow"/"deny", otherwise ""
      type   -- "v4", "v6" or "both" for "allow"/"deny", otherwise ""
      dryrun -- True if "--dry-run" was the first argument

    Prints the help text and exits with status 1 for an unknown command or
    a wrong number of arguments to logging/default/allow/deny.  Raises
    UFWError for malformed allow/deny rules.
    '''
    action = ""
    rule = ""
    rule_type = ""
    from_type = "any"
    to_type = "any"
    from_service = ""
    to_service = ""
    dryrun = False

    # These options are matched case-insensitively, so delete them by
    # position: list.remove("--dry-run") would raise ValueError for
    # e.g. "--DRY-RUN".
    if len(sys.argv) > 1 and sys.argv[1].lower() == "--dry-run":
        dryrun = True
        del sys.argv[1]

    remove = False
    if len(sys.argv) > 1 and sys.argv[1].lower() == "delete":
        remove = True
        del sys.argv[1]

    nargs = len(sys.argv)

    if nargs < 2:
        print_help()
        sys.exit(1)

    allowed_cmds = ['enable', 'disable', 'help', '--help', 'default',
                    'logging', 'status', 'version', '--version', 'allow',
                    'deny']

    if not sys.argv[1].lower() in allowed_cmds:
        print_help()
        sys.exit(1)
    action = sys.argv[1].lower()

    if action == "logging":
        if nargs < 3:
            print_help()
            sys.exit(1)
        elif sys.argv[2].lower() == "off":
            action = "logging-off"
        elif sys.argv[2].lower() == "on":
            action = "logging-on"
        else:
            print_help()
            sys.exit(1)

    if action == "default":
        if nargs < 3:
            print_help()
            sys.exit(1)
        elif sys.argv[2].lower() == "deny":
            action = "default-deny"
        elif sys.argv[2].lower() == "allow":
            action = "default-allow"
        else:
            print_help()
            sys.exit(1)

    if action == "allow" or action == "deny":
        if nargs < 3 or nargs > 12:
            print_help()
            sys.exit(1)

        rule = UFWRule(action, "any", "any")
        if remove:
            rule.remove = remove
        if nargs == 3:
            # Short form where only PORT[/PROTO] is given
            try:
                (port, proto) = parse_port_proto(sys.argv[2])
                if not re.match(r'^\d+$', port):
                    # A service name; its protocol is reconciled below.
                    to_service = port
                rule.set_protocol(proto)
                rule.set_port(port, "dst")
                rule_type = "both"
            except UFWError:
                raise UFWError("Bad port")
        elif nargs % 2 != 0:
            raise UFWError("Wrong number of arguments")
        elif not 'from' in sys.argv and not 'to' in sys.argv:
            raise UFWError("Need 'to' or 'from' clause")
        else:
            # Full form with PF-style syntax:
            #   [proto PROTO] [from ADDR [port PORT]] [to ADDR [port PORT]]
            keys = ['proto', 'from', 'to', 'port']

            # quick check: each keyword at most once ('port' at most twice,
            # i.e. once after 'from' and once after 'to')
            if sys.argv.count("to") > 1 or \
               sys.argv.count("from") > 1 or \
               sys.argv.count("proto") > 1 or \
               sys.argv.count("port") > 2:
                raise UFWError("Improper rule syntax")

            i = 1
            # Which side ('src' or 'dst') a following 'port' applies to.
            loc = ""
            for arg in sys.argv[1:]:
                # argv[1] is the action, so keywords sit at even indices and
                # each keyword's value follows it.
                if i % 2 == 0 and arg not in keys:
                    raise UFWError("Invalid token '" + arg + "'")
                if arg == "proto":
                    if i+1 < nargs:
                        rule.set_protocol(sys.argv[i+1])
                    else:
                        raise UFWError("Invalid 'proto' clause")
                elif arg == "from":
                    if i+1 < nargs:
                        faddr = sys.argv[i+1].lower()
                        if faddr == "any":
                            faddr = "0.0.0.0/0"
                            from_type = "any"
                        elif valid_address(faddr, True):
                            from_type = "v6"
                        else:
                            from_type = "v4"
                        rule.set_src(faddr)
                        loc = "src"
                    else:
                        raise UFWError("Invalid 'from' clause")
                elif arg == "to":
                    if i+1 < nargs:
                        saddr = sys.argv[i+1].lower()
                        if saddr == "any":
                            saddr = "0.0.0.0/0"
                            to_type = "any"
                        elif valid_address(saddr, True):
                            to_type = "v6"
                        else:
                            to_type = "v4"
                        rule.set_dst(saddr)
                        loc = "dst"
                    else:
                        raise UFWError("Invalid 'to' clause")
                elif arg == "port":
                    if i+1 < nargs:
                        if loc == "":
                            raise UFWError("Need 'from' or 'to' with 'port'")

                        tmp = sys.argv[i+1]
                        if not re.match(r'^\d+$', tmp):
                            # A service name; its protocol is reconciled
                            # below.
                            if loc == "src":
                                from_service = tmp
                            else:
                                to_service = tmp

                        rule.set_port(tmp, loc)
                    else:
                        raise UFWError("Invalid 'port' clause")
                i += 1

            # Figure out the type of rule (IPv4, IPv6, or both) this is
            if from_type == "any" and to_type == "any":
                rule_type = "both"
            elif from_type != "any" and to_type != "any" and \
                 from_type != to_type:
                raise UFWError("Mixed IP versions for 'from' and 'to'")
            elif from_type != "any":
                rule_type = from_type
            elif to_type != "any":
                rule_type = to_type

    # Adjust protocol: service names imply a protocol, which must agree
    # between 'from' and 'to' and with any explicit 'proto'.
    if to_service != "" or from_service != "":
        proto = ""
        if to_service != "":
            proto = get_services_proto(to_service)
        if from_service != "":
            tmp = get_services_proto(from_service)
            if proto == "any" or proto == "":
                proto = tmp
            elif tmp != "any" and tmp != proto:
                raise UFWError("Protocol mismatch (from/to)")

        # Verify found proto with specified proto
        if rule.protocol == "any":
            rule.set_protocol(proto)
        elif proto != "any" and rule.protocol != proto:
            raise UFWError("Protocol mismatch with specified protocol " +
                           rule.protocol)

    return (action, rule, rule_type, dryrun)


def get_services_proto(port):
    '''Get the protocol for a named service from the services database.

    port -- a service name such as "ssh"

    Returns "tcp" or "udp" when the service is registered only for that
    protocol, "any" when it is registered for both, and "" when it is
    registered only for some other protocol.  Raises socket.error if the
    service is not known at all.
    '''
    # An unknown service is an error for the caller to handle.
    socket.getservbyname(port)

    found = []
    for candidate in ("tcp", "udp"):
        try:
            socket.getservbyname(port, candidate)
        except socket.error:
            continue
        found.append(candidate)

    if len(found) == 2:
        return "any"
    if found:
        return found[0]
    return ""

def parse_port_proto(str):
    '''Split a "PORT" or "PORT/PROTO" specification into (port, proto).

    A bare port is paired with the protocol "any".  A specification with
    more than one '/' raises UFWError.
    '''
    pieces = str.split('/')
    if len(pieces) > 2:
        raise UFWError("Bad port/protocol")
    if len(pieces) == 2:
        return (pieces[0], pieces[1])
    return (pieces[0], "any")


def valid_address(addr, v6=False):
    '''Validate an IP address, optionally followed by a '/NN' netmask.

    addr -- address string such as '10.0.0.1', '10.0.0.0/8' or '::1/128'
    v6   -- validate as IPv6 when True, IPv4 otherwise

    Returns True if addr is a valid address of the requested family with,
    if present, an integer netmask of at most 32 (IPv4) or 128 (IPv6);
    False otherwise.  If IPv6 is requested but python lacks IPv6 support,
    a warning is printed and False is returned.
    '''
    if v6 and not socket.has_ipv6:
        warn("python does not have IPv6 support.")
        return False

    # Quick and dirty pre-filter: the longest accepted form is a 39 char
    # IPv6 address plus '/128'.  '\Z' (unlike '$') does not let a trailing
    # newline slip through.
    if len(addr) > 43 or not re.match(r'^[a-fA-F0-9:\./]+\Z', addr):
        return False

    net = addr.split('/')
    if len(net) > 2:
        return False

    if len(net) == 2:
        # Only allow integer netmasks (the regex already rules out '-').
        if not re.match(r'^[0-9]+\Z', net[1]):
            return False
        max_mask = 32
        if v6:
            max_mask = 128
        if int(net[1]) > max_mask:
            return False

    family = socket.AF_INET
    if v6:
        family = socket.AF_INET6
    try:
        socket.inet_pton(family, net[0])
    except (socket.error, ValueError):
        return False

    return True

def print_help():
    '''Print help message'''
    print '''
Usage: ''' + programName + ''' COMMAND

Commands:
  enable			Enables the firewall
  disable			Disables the firewall
  default ARG			set default policy to ALLOW or DENY
  logging ARG			set logging to ON or OFF
  allow|deny RULE		allow or deny RULE
  delete allow|deny RULE	delete the allow/deny RULE
  status			show firewall status
  version			display version information
'''

def open_file_read(f):
    '''Opens the specified file read-only.

    Returns the open file object.  Raises UFWError if the file cannot be
    opened.
    '''
    try:
        return open(f, 'r')
    except (IOError, OSError):
        # python 2's open() reports failures as IOError, not OSError, so
        # both are caught to turn them into a UFWError.
        raise UFWError("Couldn't open '" + f + "' for reading")


def open_files(f):
    '''Open f for reading and create a companion temporary file.

    Returns a dict with keys 'orig' (file object for f), 'origname' (f),
    'tmp' (OS-level descriptor from mkstemp) and 'tmpname' (its path).
    If the temporary file cannot be created, the already-opened original
    is closed before the error is re-raised.
    '''
    handle = open_file_read(f)
    try:
        tmp_fd, tmp_path = mkstemp()
    except Exception:
        handle.close()
        raise
    return {"orig": handle, "origname": f, "tmp": tmp_fd, "tmpname": tmp_path}


def close_files(fns, update = True):
    '''Closes the specified files (as returned by open_files), and update
       original file with the temporary file.

    fns    -- dict with 'orig', 'origname', 'tmp' and 'tmpname' keys
    update -- when True, copy the temp file's contents over the original
              (keeping the original's permission bits); when False, just
              discard the temp file

    The temporary file is always removed, even if the copy fails; any
    error from the copy is re-raised.
    '''
    fns['orig'].close()
    os.close(fns['tmp'])

    try:
        if update:
            # Give the temp file the original's mode first, because
            # shutil.copy() copies the mode bits along with the data.
            shutil.copystat(fns['origname'], fns['tmpname'])
            shutil.copy(fns['tmpname'], fns['origname'])
    finally:
        os.unlink(fns['tmpname'])


def cmd(command):
    '''Run command (an argv list) and capture its combined output.

    Returns [returncode, output], where output holds stdout and stderr
    interleaved.  If the program cannot be started, returns
    [127, <error text>] instead of raising.
    '''
    try:
        proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)
    except OSError:
        return [127, str(sys.exc_info()[1])]

    output = proc.communicate()[0]
    return [proc.returncode, output]


def cmd_pipe(command1, command2):
    '''Run "command1 | command2" (both argv lists) without a shell.

    command2's standard output and both commands' standard error are
    inherited from this process, not captured.

    Returns [returncode, output].  returncode is command2's exit status,
    unless command2 succeeded while command1 failed, in which case
    command1's status is returned so that a failed producer (e.g. 'cat' of
    a missing file) is not mistaken for success.  output is None because
    command2's output is not captured.  If either program cannot be
    started, returns [127, <error text>].
    '''
    try:
        sp1 = subprocess.Popen(command1, stdout=subprocess.PIPE)
    except OSError:
        return [127, str(sys.exc_info()[1])]

    try:
        sp2 = subprocess.Popen(command2, stdin=sp1.stdout)
    except OSError:
        err = sys.exc_info()[1]
        # Nobody will read the pipe: close it so command1 can finish, and
        # reap it so it does not linger as a zombie.
        sp1.stdout.close()
        sp1.wait()
        return [127, str(err)]

    # Drop our copy of the read end so command1 gets SIGPIPE if command2
    # exits early, instead of blocking forever.
    sp1.stdout.close()
    out = sp2.communicate()[0]
    rc1 = sp1.wait()

    rc = sp2.returncode
    if rc == 0 and rc1 != 0:
        rc = rc1
    return [rc, out]


def error(msg):
    '''Report msg on stderr with an "ERROR: " prefix, then exit with status 1.'''
    sys.stderr.write("ERROR: " + msg + "\n")
    sys.exit(1)


def warn(msg):
    '''Report msg on stderr with a "WARN: " prefix (execution continues).'''
    sys.stderr.write("WARN: " + msg + "\n")


def debug(msg):
    '''Report msg on stderr with a "DEBUG: " prefix, but only when the
    module-level 'debugging' flag is set.'''
    if not debugging:
        return
    sys.stderr.write("DEBUG: " + msg + "\n")

#
# Classes
#
class UFWError(Exception):
    '''Exception raised for ufw-specific failures.

    The message is kept in the 'value' attribute; str() of the exception
    is the repr() of that value (so strings appear quoted).
    '''
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return "%r" % (self.value,)


class UFWRule:
    '''This class represents firewall rules.

    Attributes:
      action   -- "allow" or "deny"
      protocol -- "tcp", "udp" or "any"
      dport    -- destination port as a decimal string, or "any"
      sport    -- source port as a decimal string, or "any"
      dst, src -- destination/source address, optionally with a '/NN'
                  netmask; "anywhere" is stored as 0.0.0.0/0 or ::/0
                  depending on v6
      v6       -- True for an IPv6 rule
      remove   -- True when the rule is to be deleted
    '''
    def __init__(self, action, protocol, dport="any", dst="0.0.0.0/0",
                 sport="any", src="0.0.0.0/0"):
        '''Build a rule; raises UFWError if any component is invalid.'''
        self.remove = False
        self.v6 = False
        # Placeholders so _fix_anywhere() can run before both are set.
        self.dst = ""
        self.src = ""
        self.set_action(action)
        self.set_protocol(protocol)
        self.set_port(dport)
        self.set_port(sport, "src")
        self.set_src(src)
        self.set_dst(dst)

    def __str__(self):
        '''Return the iptables-style representation of the rule.'''
        # __str__ must return a string; printing here instead would make
        # str(rule) raise TypeError.
        return self.format_rule()

    def format_rule(self):
        '''Format rule as iptables arguments for later parsing.

        Returns e.g. "-p tcp --dport 22 -j ACCEPT".  "Anywhere" addresses
        and "any" ports are omitted; protocol "any" is written "-p all";
        any action other than "allow" becomes DROP.
        '''
        if self.protocol == "any":
            out = " -p all"
        else:
            out = " -p " + self.protocol

        if self.dst != "0.0.0.0/0" and self.dst != "::/0":
            out += " -d " + self.dst
        if self.dport != "any":
            out += " --dport " + self.dport
        if self.src != "0.0.0.0/0" and self.src != "::/0":
            out += " -s " + self.src
        if self.sport != "any":
            out += " --sport " + self.sport
        if self.action == "allow":
            out += " -j ACCEPT"
        else:
            out += " -j DROP"

        return out.strip()

    def set_action(self, action):
        '''Sets action of the rule: "allow" (any case) or else "deny".'''
        # Store the normalized lowercase value: format_rule() and match()
        # compare against "allow", so keeping e.g. "ALLOW" verbatim would
        # silently turn an allow rule into a DROP.
        if action.lower() == "allow":
            self.action = "allow"
        else:
            self.action = "deny"

    def set_port(self, port, loc="dst"):
        '''Sets port and location (destination or source) of the rule.

        port -- "any", a decimal port 1-65535, or a service name known to
                socket.getservbyname() (stored as its port number)
        loc  -- "src" sets the source port, anything else the destination

        Raises UFWError if port is not valid.
        '''
        if port == "any":
            pass
        elif re.match(r'^\d+\Z', port):
            # '\Z' rather than '$' so that e.g. "22\n" is rejected.
            if int(port) < 1 or int(port) > 65535:
                raise UFWError("Bad port '" + port + "'")
        elif re.match(r'^\w[\w\-]+', port):
            try:
                port = socket.getservbyname(port)
            except Exception:
                raise UFWError("Bad port '" + port + "'")
        else:
            raise UFWError("Bad port '" + port + "'")

        if loc == "src":
            self.sport = str(port)
        else:
            self.dport = str(port)

    def set_protocol(self, protocol):
        '''Sets protocol of the rule ("tcp", "udp" or "any").

        Raises UFWError for any other value.
        '''
        if protocol == "tcp" or protocol == "udp" or protocol == "any":
            self.protocol = protocol
        else:
            raise UFWError("Unsupported protocol '" + protocol + "'")

    def _fix_anywhere(self):
        '''Rewrite "anywhere" src/dst values to the form matching self.v6.'''
        if self.v6:
            anywhere, other = "::/0", "0.0.0.0/0"
        else:
            anywhere, other = "0.0.0.0/0", "::/0"
        if self.dst in ("any", other):
            self.dst = anywhere
        if self.src in ("any", other):
            self.src = anywhere

    def set_v6(self, v6):
        '''Sets whether this is ipv6 rule, and adjusts src and dst
           accordingly.
        '''
        self.v6 = v6
        self._fix_anywhere()

    def set_src(self, addr):
        '''Sets source address of rule ("any", IPv4 or IPv6, lowercased).

        Raises UFWError if addr is not a valid address.
        '''
        tmp = addr.lower()

        if tmp != "any" and not valid_address(tmp) and \
           not valid_address(tmp, True):
            raise UFWError("Bad source address")
        self.src = tmp
        self._fix_anywhere()

    def set_dst(self, addr):
        '''Sets destination address of rule ("any", IPv4 or IPv6, lowercased).

        Raises UFWError if addr is not a valid address.
        '''
        tmp = addr.lower()

        if tmp != "any" and not valid_address(tmp) and \
           not valid_address(tmp, True):
            raise UFWError("Bad destination address")
        self.dst = tmp
        self._fix_anywhere()

    def match(x, y):
        '''Check if rules match
        Return codes:
          0  match
          1  no match
         -1  match all but action
        '''
        for attr in ('dport', 'sport', 'protocol', 'src', 'dst', 'v6'):
            if getattr(x, attr) != getattr(y, attr):
                debug("No match")
                return 1

        if x.action == y.action:
            debug("Found exact match")
            return 0
        debug("Found opposite match")
        return -1


class UFWBackend:
    '''Interface for backends.

    Holds the parsed configuration (self.defaults), the dry-run flag and
    the rule lists.  Subclasses provide _read_rules() and the API methods
    at the bottom, which raise UFWError here.
    '''
    def __init__(self, name, d):
        '''Initialize the backend.

        name -- backend name (e.g. "iptables")
        d    -- True for dry-run mode

        Runs the security checks, loads the configuration and reads the
        existing rules; any failure propagates to the caller.
        '''
        self.defaults = {}
        self.name = name
        self.dryrun = d
        self.rules = []
        self.rules6 = []
        self._do_checks()
        self._get_defaults()
        self._read_rules()

    def _is_enabled(self):
        '''Return True if the configuration sets enabled=yes.'''
        return self.defaults.get('enabled') == 'yes'

    def use_ipv6(self):
        '''Return True if ipv6=yes is configured and the kernel has IPv6.'''
        return self.defaults.get('ipv6') == 'yes' and \
               os.path.exists("/proc/sys/net/ipv6")

    def _do_checks(self):
        '''Perform basic security checks:
        is setuid or setgid (for non-Linux systems)
        checks that script is owned by root
        checks that every component in absolute path are owned by root
        checks that every component of absolute path are not a symlink
        warn if script is group writable
        warn if part of script path is group writable

        Doing this at the beginning causes a race condition with later
        operations that don't do these checks.  However, if the user running
        this script is root, then need to be root to exploit the race
        condition (and you are hosed anyway...)

        Raises UFWError on any failed check.
        '''

        if disable_checks:
            warn("Checks disabled")
            return True

        # Not needed on Linux, but who knows the places we will go...
        if os.getuid() != os.geteuid():
            raise UFWError("ERROR: this script should not be SUID")
        if os.getgid() != os.getegid():
            raise UFWError("ERROR: this script should not be SGID")
        uid = os.getuid()

        if uid != 0:
            raise UFWError("You need to be root to run this script")

        pat = re.compile(r'^\.')
        for top in list(files.values()) + [os.path.abspath(sys.argv[0])]:
            path = top
            # Walk from the file itself up to the root directory.
            while True:
                debug("Checking " + path)
                if pat.search(os.path.basename(path)):
                    raise UFWError("found hidden directory in path: " + path)

                try:
                    statinfo = os.stat(path)
                except OSError:
                    raise UFWError("Couldn't stat '" + path + "'")
                mode = statinfo[ST_MODE]

                if os.path.islink(path):
                    raise UFWError("found symbolic link in path: " + path)
                if statinfo.st_uid != 0:
                    raise UFWError("uid is " + str(uid) + " but '" + path + \
                                   "' is owned by " + str(statinfo.st_uid))
                if mode & S_IWOTH:
                    raise UFWError(path + " is world writable!")
                if mode & S_IWGRP:
                    warn(path + " is group writable")

                parent = os.path.dirname(path)
                if parent == path:
                    # Reached the root: "/" (or "//", which POSIX path
                    # normalization preserves and which is its own parent).
                    break
                if not parent:
                    # A relative path ran out of components before "/".
                    raise UFWError("'" + top + "' is not an absolute path")
                path = parent

        for f in files:
            if not os.path.isfile(files[f]):
                raise UFWError("'" + f + "' file '" + files[f] + \
                               "' does not exist")

    def _get_defaults(self):
        '''Get all settings from the defaults and conf files.

        Lines starting with NAME=VALUE (word characters on both sides of
        '=') are stored lowercased in self.defaults, the conf file
        overriding the defaults file.  Lines whose value starts with a
        non-word character (e.g. a quote) are ignored.
        '''
        self.defaults = {}
        pat = re.compile(r'^\w+=\w+')
        for f in [files['defaults'], files['conf']]:
            orig = open_file_read(f)
            try:
                for line in orig:
                    if pat.search(line):
                        tmp = re.split(r'=', line.strip())
                        self.defaults[tmp[0].lower()] = tmp[1].lower()
            finally:
                orig.close()

    def set_default(self, f, opt, value):
        '''Sets option opt to value in file f.

        Every line starting with "opt=" is replaced by "opt=value"; all
        other lines are copied unchanged.  Raises UFWError if opt is not a
        plain word.  On a write error the original file is left untouched.
        '''
        # '\Z' rather than '$' so a trailing newline is not accepted.
        if not re.match(r'^[\w_]+\Z', opt):
            raise UFWError("Invalid option")

        fns = open_files(f)
        fd = fns['tmp']

        pat = re.compile(r'^' + opt + '=')
        try:
            for line in fns['orig']:
                if pat.search(line):
                    os.write(fd, opt + "=" + value + "\n")
                else:
                    os.write(fd, line)
        except Exception:
            # Discard the partial temp file instead of leaking it.
            close_files(fns, False)
            raise

        close_files(fns)

    # API overrides
    def get_loglevel(self):
        raise UFWError("UFWBackend.get_loglevel: need to override")

    def set_loglevel(self, level):
        raise UFWError("UFWBackend.set_loglevel: need to override")

    def set_default_policy(self, policy):
        raise UFWError("UFWBackend.set_default_policy: need to override")

    def get_status(self):
        raise UFWError("UFWBackend.get_status: need to override")

    def set_rule(self, rule):
        raise UFWError("UFWBackend.set_rule: need to override")

    def start_firewall(self):
        raise UFWError("UFWBackend.start_firewall: need to override")

    def stop_firewall(self):
        raise UFWError("UFWBackend.stop_firewall: need to override")


class UFWBackendIptables(UFWBackend):
    def __init__(self, d):
        '''Register the iptables-specific files, then initialize the backend.

        d -- dry-run flag, passed through to UFWBackend
        '''
        extra_files = {
            'rules': '#STATE_PREFIX#/user.rules',
            'before_rules': '#CONFIG_PREFIX#/ufw/before.rules',
            'after_rules': '#CONFIG_PREFIX#/ufw/after.rules',
            'rules6': '#STATE_PREFIX#/user6.rules',
            'before6_rules': '#CONFIG_PREFIX#/ufw/before6.rules',
            'after6_rules': '#CONFIG_PREFIX#/ufw/after6.rules',
            'init': '#CONFIG_PREFIX#/init.d/ufw',
        }
        files.update(extra_files)

        UFWBackend.__init__(self, "iptables", d)

    def get_loglevel(self):
        '''Show current log level of firewall'''
        print "get_loglevel: TODO"

    def set_default_policy(self, policy):
        '''Sets default policy of firewall.

        policy -- "allow" (DEFAULT_INPUT_POLICY="ACCEPT") or "deny"
                  (DEFAULT_INPUT_POLICY="DROP") in the defaults file

        Returns a message for the user.  Raises UFWError for any other
        policy, in dry-run mode too; in dry-run mode nothing is written.
        '''
        # Validate first, so an unsupported policy is rejected even in
        # dry-run mode instead of being reported as changed.
        if policy == "allow":
            target = "\"ACCEPT\""
        elif policy == "deny":
            target = "\"DROP\""
        else:
            raise UFWError("Unsupported policy '" + policy + "'")

        if not self.dryrun:
            self.set_default(files['defaults'], "DEFAULT_INPUT_POLICY",
                             target)

        rstr = "Default policy changed to '" + policy + "'\n" + \
              "(be sure to update your rules accordingly)"

        return rstr

    def set_loglevel(self, level):
        '''Sets log level of firewall.

        level -- "off" prefixes every rule line (starting with '-') that
                 contains a LOG target with the ufw comment marker; any
                 other value (normally "on") strips that marker from
                 commented lines containing a LOG target

        Rewrites the user, before and after rule files for IPv4 and IPv6.
        In dry-run mode the files are opened but left unchanged.  Returns
        "Logging disabled" for "off", otherwise "Logging enabled".
        '''
        comment_str = "# " + programName + "_comment #"

        # Loop-invariant patterns, compiled once instead of per line.
        if level == "on":
            pat = re.compile(r'^#.*\sLOG\s')
        else:
            pat = re.compile(r'^-.*\sLOG\s')
        pat_comment = re.compile(r"^" + comment_str + r"\s*")

        for f in [files['rules'], files['rules6'], \
                  files['before_rules'], files['before6_rules'], \
                  files['after_rules'], files['after6_rules']]:
            fns = open_files(f)
            if self.dryrun:
                close_files(fns, False)
                continue

            fd = fns['tmp']
            try:
                for line in fns['orig']:
                    if pat.search(line):
                        if level == "off":
                            line = comment_str + ' ' + line
                        else:
                            line = pat_comment.sub('', line)
                    os.write(fd, line)
            except Exception:
                # Discard the partial temp file; keep the original intact.
                close_files(fns, False)
                raise
            close_files(fns)

        if level == "off":
            return "Logging disabled"
        else:
            return "Logging enabled"

    def get_status(self):
        '''Show current status of firewall.

        Returns "Firewall not loaded" if the ufw-user-input chain does not
        exist.  Otherwise it returns "Firewall loaded", followed, when
        there are any, by a To/Action/From table of the ACCEPT/DROP rules
        in the ufw user input chains.  In dry-run mode it only describes
        the checks that would be run.

        Raises UFWError if iptables or (when IPv6 is in use) ip6tables
        fails.
        '''
        # Use the instance's dry-run flag (as _reload_user_rules() does)
        # rather than depending on a module-level 'dryrun'.
        if self.dryrun:
            out = "> Checking iptables\n"
            if self.use_ipv6():
                out += "> Checking ip6tables\n"
            return out

        # Is the firewall loaded at all?
        (rc, out) = cmd(['iptables', '-L', 'ufw-user-input', '-n'])
        if rc != 0:
            return "Firewall not loaded"

        # Get the output of iptables for parsing
        (rc, out) = cmd(['iptables', '-L', '-n'])
        if rc != 0:
            raise UFWError("problem running iptables")

        # Empty ip6tables output just means there are no IPv6 rules to
        # show; it no longer makes the whole status an empty string.
        out6 = ""
        if self.use_ipv6():
            (rc, out6) = cmd(['ip6tables', '-L', 'ufw6-user-input', '-n'])
            if rc != 0:
                raise UFWError("problem running ip6tables")

        if out == "" and out6 == "":
            return "Firewall loaded"

        rules = []
        pat_chain = re.compile(r'^Chain ')
        pat_target = re.compile(r'^target')
        for (ip_type, lines, chain) in [("v4", out, 'ufw-user-input'),
                                        ("v6", out6, 'ufw6-user-input')]:
            pat_ufw = re.compile(r'^Chain ' + chain)
            in_ufw_input = False
            for line in lines.split('\n'):
                if pat_ufw.search(line):
                    in_ufw_input = True
                elif pat_chain.search(line):
                    in_ufw_input = False
                elif pat_target.search(line):
                    # column header line
                    pass
                elif in_ufw_input:
                    r = self._parse_iptables_status(line, ip_type)
                    if r is not None:
                        rules.append(r)

        table = ""
        for r in rules:
            location = {}
            for loc in ['dst', 'src']:
                if loc == 'src':
                    port = r.sport
                    addr = r.src
                else:
                    port = r.dport
                    addr = r.dst

                location[loc] = ""
                if addr != "0.0.0.0/0" and addr != "::/0":
                    location[loc] = addr

                if port != "any":
                    if location[loc] == "":
                        location[loc] = port
                    else:
                        location[loc] += " " + port

                    if r.protocol != "any":
                        location[loc] += ":" + r.protocol
                elif addr == "0.0.0.0/0":
                    location[loc] = "Anywhere"
                elif addr == "::/0":
                    location[loc] = "Anywhere (v6)"

            table += "%-26s %-8s%s\n" % (location['dst'], r.action.upper(),
                                         location['src'])

        if table != "":
            header = "\n\n%-26s %-8s%s\n" % ("To", "Action", "From")
            header += "%-26s %-8s%s\n" % ("--", "------", "----")
            table = header + table

        return "Firewall loaded" + table

    def stop_firewall(self):
        '''Stops the firewall'''
        openconf = '''*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
COMMIT
'''
        if dryrun:
            print "> iptables -F"
            print "> iptables -X"
            print "> echo\n" + openconf + "> | iptables-restore"
            if self.use_ipv6():
                print "> ip6tables -F"
                print "> ip6tables -X"
                print "> echo\n" + openconf + "> | ip6tables-restore"
            return

        try:
            (tmp, tmpname) = mkstemp()
        except Exception:
            raise
        os.write(tmp, openconf)
        os.close(tmp)

        # Flush the firewall
        (rc, out) = cmd(['iptables', '-F'])
        if rc != 0:
            raise UFWError("problem running iptables 'flush'")

        # Remove user chains
        (rc, out) = cmd(['iptables', '-X'])
        if rc != 0:
            raise UFWError("problem running iptables 'delete'")

        # Set default open
        (rc, out) = cmd_pipe(['cat', tmpname], ['iptables-restore'])
        if rc != 0:
            raise UFWError("problem running iptables")

        if self.use_ipv6():
            # Flush the firewall
            (rc, out) = cmd(['ip6tables', '-F'])
            if rc != 0:
                raise UFWError("problem running ip6tables 'flush'")

            # Remove user chains
            (rc, out) = cmd(['ip6tables', '-X'])
            if rc != 0:
                raise UFWError("problem running ip6tables 'delete'")

            # Set default open
            (rc, out) = cmd_pipe(['cat', tmpname], ['ip6tables-restore'])
            if rc != 0:
                raise UFWError("problem running ip6tables")

        os.unlink(tmpname)

    def start_firewall(self):
        '''Starts the firewall'''
        if dryrun:
            print "> running initscript"
        else:
            (rc, out) = cmd([files['init'], 'start'])
            if rc != 0:
                raise UFWError("problem running init script")

    def _reload_user_rules(self):
        '''Reload firewall rules file'''
        if self.dryrun:
            print "> iptables-restore < " + files['rules'] 
            if self.use_ipv6():
                print "> ip6tables-restore < " + files['rules6'] 
        else:
            (rc, out) = cmd_pipe(['cat', files['rules']], \
                                 ['iptables-restore', '-n'])
            if rc != 0:
                raise UFWError("problem running iptables")

            if self.use_ipv6():
                (rc, out) = cmd_pipe(['cat', files['rules6']], \
                                     ['ip6tables-restore', '-n'])
                if rc != 0:
                    raise UFWError("problem running ip6tables")

    def _get_rules_from_formatted(self, frule):
        '''Return list of iptables rules appropriate for sending'''
        snippets = []
        pat_proto = re.compile(r'-p all ')
        pat_port = re.compile(r'port ')
        if pat_proto.search(frule):
            if pat_port.search(frule):
                snippets.append(pat_proto.sub('-p tcp ', frule))
                snippets.append(pat_proto.sub('-p udp ', frule))
            else:
                snippets.append(pat_proto.sub('', frule))
        else:
            snippets.append(frule)

        return snippets

    def _parse_iptables_status(self, line, type):
        '''Parses a line from iptables -L -n into a UFWRule.

        line -- one rule line from 'iptables -L -n' or 'ip6tables -L -n'
        type -- "v6" for ip6tables output, anything else for iptables

        Returns a UFWRule, or None for lines that are not ACCEPT/DROP rules
        and for rules UFWRule cannot represent: a protocol other than
        tcp/udp/all, an unparsable address, or an unparsable port such as
        a "dpts:" port range.
        '''
        fields = line.split()

        if type == "v6":
            # ip6tables hack since its opt field is blank (unlike iptables)
            fields.insert(2, '--')

        if len(fields) < 5:
            debug("Couldn't parse line '" + line + "'")
            return None

        rule = UFWRule("ACCEPT", "any", "any")
        if fields[0] == 'ACCEPT':
            rule.set_action('allow')
        elif fields[0] == 'DROP':
            rule.set_action('deny')
        else:
            # RETURN and LOG are valid, but we skip them
            return None

        if fields[1] == 'tcp' or fields[1] == 'udp':
            rule.set_protocol(fields[1])
        elif fields[1] == "0" or fields[1] == "all":
            rule.set_protocol('any')
        else:
            # UFWRule only supports tcp/udp/any; passing anything else to
            # set_protocol() raises UFWError and would abort the listing.
            warn("Couldn't parse line '" + line + "'")
            return None

        if type == "v6":
            # ip6tables hack since it doesn't have a space between the
            # destination address and the protocol on a large destination
            # address (see Debian bug #464244).
            mashed = fields[4][(len(fields[4]) - 3):]
            if mashed == 'tcp' or mashed == 'udp':
                fields.insert(5, mashed)
                fields[4] = fields[4][:(len(fields[4]) - 3)]

        try:
            rule.set_src(fields[3])
            rule.set_dst(fields[4])
        except UFWError:
            warn("Couldn't parse line '" + line + "'")
            return None

        # Optional "dpt:N" / "spt:N" match columns (fields 6 and 7).
        try:
            for extra in fields[6:8]:
                if re.match('dpt', extra):
                    rule.set_port(extra[4:], "dst")
                elif re.match('spt', extra):
                    rule.set_port(extra[4:], "src")
        except UFWError:
            # e.g. a port range ("dpts:1000:2000") is not representable
            warn("Couldn't parse line '" + line + "'")
            return None

        rule.set_v6(type == "v6")

        return rule
        
    def _read_rules(self):
        '''Read in rules that were added by ufw.

        Each rules file is scanned for lines starting with '### tuple ###'.
        The rest of such a line must be six whitespace-separated fields,
        which are passed to UFWRule in order. Rules from files['rules'] are
        appended to self.rules; rules from files['rules6'] (only read when
        use_ipv6() is true) are marked v6 and appended to self.rules6.
        Tuples with the wrong field count, or that UFWRule rejects with
        UFWError, are skipped with a warning.
        '''
        # Pair each file with its address family up front instead of
        # comparing paths per line; this also avoids needing
        # files['rules6'] when IPv6 is disabled.
        rfns = [(files['rules'], False)]
        if self.use_ipv6():
            rfns.append((files['rules6'], True))

        pat_tuple = re.compile(r'^### tuple ###\s*')
        for (fn, is_v6) in rfns:
            orig = open_file_read(fn)
            try:
                for line in orig:
                    if not pat_tuple.match(line):
                        continue
                    rule_tuple = pat_tuple.sub('', line)
                    tmp = re.split(r'\s+', rule_tuple.strip())
                    if len(tmp) != 6:
                        warn("Skipping malformed tuple (bad length): " +
                             rule_tuple)
                        continue
                    try:
                        rule = UFWRule(tmp[0], tmp[1], tmp[2], tmp[3],
                                       tmp[4], tmp[5])
                        rule.set_v6(is_v6)
                        if is_v6:
                            self.rules6.append(rule)
                        else:
                            self.rules.append(rule)
                    except UFWError:
                        warn("Skipping malformed tuple: " + rule_tuple)
            finally:
                # Close the file even if parsing raises something unexpected.
                orig.close()

    def _write_rules(self, v6=False):
        '''Write out new rules to file to user chain file

        The existing file is copied up to and including the line matching
        '### RULES ###'. Then, for every in-memory rule, a '### tuple ###'
        comment line and the iptables line(s) from _get_rules_from_formatted()
        are emitted, followed by the END RULES marker, RETURN rules for the
        user input/output/forward chains and COMMIT. Output goes to the
        temporary file returned by open_files(), or to stdout in dry-run
        mode.

        v6 selects files['rules6'], self.rules6 and the "ufw6" chain prefix
        instead of files['rules'], self.rules and "ufw".
        '''
        rules_file = files['rules6'] if v6 else files['rules']
        fns = open_files(rules_file)

        prefix = "ufw6" if v6 else "ufw"
        current = self.rules6 if v6 else self.rules
        out_fd = sys.stdout.fileno() if self.dryrun else fns['tmp']

        # Keep the file's header verbatim, through the RULES marker.
        marker = re.compile(r'^### RULES ###')
        for line in fns['orig']:
            os.write(out_fd, line)
            if marker.match(line):
                break

        input_lead = "-A " + prefix + "-user-input "
        for r in current:
            # format the rule before anything for it is written
            formatted = input_lead + r.format_rule() + "\n"
            os.write(out_fd, "\n### tuple ### %s %s %s %s %s %s\n" %
                     (r.action, r.protocol, r.dport, r.dst, r.sport, r.src))
            for snippet in self._get_rules_from_formatted(formatted):
                os.write(out_fd, snippet)

        os.write(out_fd, "\n### END RULES ###\n")
        for chain in ("-user-input", "-user-output", "-user-forward"):
            os.write(out_fd, "-A " + prefix + chain + " -j RETURN\n")
        os.write(out_fd, "COMMIT\n")

        if not self.dryrun:
            close_files(fns)
        else:
            close_files(fns, False)

    def set_rule(self, rule):
        '''Updates firewall with rule by:
        * appending the rule to the chain if new rule and firewall enabled
        * deleting the rule from the chain if found and firewall enabled
        * updating user rules file
        * reloading the user rules file if rule is modified

        A rule with rule.remove set is deleted; otherwise it is added, or
        replaces an existing rule that differs only in its action.

        Returns a status string: "Rule added", "Rule deleted" or
        "Rule updated" when the running firewall was changed, otherwise
        "Rules updated"; " (v6)" is appended for IPv6 rules.

        Raises UFWError if an IPv6 rule is given while IPv6 is disabled, if
        the user rules file cannot be written, or if the running firewall
        cannot be updated.
        '''
        if rule.v6 and not self.use_ipv6():
            raise UFWError("Adding IPv6 rule failed: IPv6 not enabled")

        newrules = []
        found = False
        modified = False

        rules = self.rules
        if rule.v6:
            rules = self.rules6

        # First construct the new rules list
        for r in rules:
            ret = UFWRule.match(r, rule)
            if ret == 0 and not found:
                # If find the rule, add it if it's not to be removed, otherwise
                # skip it.
                found = True
                if not rule.remove:
                    newrules.append(rule)
            elif ret < 0 and not rule.remove:
                # If only the action is different, replace the rule if it's not
                # to be removed.
                found = True
                modified = True
                newrules.append(rule)
            else:
                newrules.append(r)

        # Add rule to the end if it was not already added.
        if not found and not rule.remove:
            newrules.append(rule)

        if rule.v6:
            self.rules6 = newrules
        else:
            self.rules = newrules

        # Update the user rules file; a failed write must be reported rather
        # than silently ignored.
        try:
            self._write_rules(rule.v6)
        except Exception:
            raise UFWError("Couldn't update rules file")

        v6_suffix = ""
        if rule.v6:
            v6_suffix = " (v6)"
        rstr = "Rules updated" + v6_suffix

        # Operate on the chains
        if self._is_enabled() and not self.dryrun:
            flag = ""
            if modified:
                # Reload the chain
                self._reload_user_rules()
                rstr = "Rule updated" + v6_suffix
            elif found and rule.remove:
                flag = '-D'
                rstr = "Rule deleted" + v6_suffix
            elif not found and not rule.remove:
                flag = '-A'
                rstr = "Rule added" + v6_suffix

            if flag != "":
                exe = "iptables"
                chain = "ufw-user-input"
                if rule.v6:
                    exe = "ip6tables"
                    chain = "ufw6-user-input"

                # Is the firewall running?
                (rc, out) = cmd([exe, '-L', chain, '-n'])
                if rc != 0:
                    raise UFWError("Could not update running firewall")

                for s in self._get_rules_from_formatted(rule.format_rule()):
                    (rc, out) = cmd([exe, flag, chain] + s.split())
                    if rc != 0:
                        print(out)
                        raise UFWError("Could not update running firewall")

                if flag == "-A":
                    # delete the RETURN rule then add it back, so it is at the
                    # end, after every newly appended snippet
                    (rc, out) = cmd([exe, '-D', chain, '-j', 'RETURN'])
                    if rc != 0:
                        print(out)

                    (rc, out) = cmd([exe, '-A', chain, '-j', 'RETURN'])
                    if rc != 0:
                        print(out)
        return rstr


class UFWFrontend:
    '''UI'''
    def __init__(self, be):
        self.backend = be

    def set_enabled(self, enabled):
        '''Toggles ENABLED state in of #CONFIG_PREFIX#/ufw/ufw.conf'''
        try:
            if enabled:
                if not self.backend._is_enabled():
                    self.backend.set_default(files['conf'], \
                                             "ENABLED", "yes")
                self.backend.start_firewall()
                print "Firewall started and enabled on system startup"
            else:
                if self.backend._is_enabled():
                    self.backend.set_default(files['conf'], "ENABLED", \
                                             "no")
                self.backend.stop_firewall()
                print "Firewall stopped and disabled on system startup"
        except UFWError, e:
            error(e.value)

    def set_default_policy(self, policy):
        '''Sets default policy of firewall'''
        str = ""
        try:
            str = self.backend.set_default_policy(policy)
            if self.backend._is_enabled():
                self.backend.stop_firewall()
                self.backend.start_firewall()
        except UFWError, e:
            error(e.value)

        print str

    def set_loglevel(self, level):
        '''Sets log level of firewall'''
        str = ""
        try:
            str = self.backend.set_loglevel(level)
            if self.backend._is_enabled():
                # have to just restart because of ordering of LOG rules
                self.backend.stop_firewall()
                self.backend.start_firewall()
        except UFWError, e:
            error(e.value)

        print str

    def get_status(self):
        '''Shows status of firewall'''
        try:
            out = self.backend.get_status()
        except UFWError, e:
            error(e.value)

        print out

    def set_rule(self, rule, type):
        '''Updates firewall with rule'''
        res = ""
        try:
            if self.backend.use_ipv6():
                if type == "v4":
                    rule.set_v6(False)
                    res = self.backend.set_rule(rule)
                elif type == "v6":
                    rule.set_v6(True)
                    res = self.backend.set_rule(rule)
                elif type == "both":
                    rule.set_v6(False)
                    res = self.backend.set_rule(rule)
                    rule.set_v6(True)
                    res += "\n" + str(self.backend.set_rule(rule))
                else:
                    raise UFWError("Invalid type '" + type + "'")
            else:
                if type == "v4" or type == "both":
                    rule.set_v6(False)
                    res = self.backend.set_rule(rule)
                elif type == "v6":
                    raise UFWError("IPv6 support not enabled")
                else:
                    raise UFWError("Invalid type '" + type + "'")
        except UFWError, e:
            error(e.value)

        print res


# Execution starts here
action = ""
rule = ""
dryrun = False
try:
    (action, rule, type, dryrun) = process_args()
except UFWError, e:
    error(e.value)

if action == "help" or action == "--help":
    print_help()
    sys.exit(0)
elif action == "version" or action == "--version":
    print programName + " " + version
    print "Copyright (C) 2008 Canonical Ltd."
    sys.exit(0)

try:
    ufw = UFWFrontend(UFWBackendIptables(dryrun))
except UFWError, e:
    error(e.value)
except Exception:
    raise

if action == "logging-on":
    ufw.set_loglevel("on")
elif action == "logging-off":
    ufw.set_loglevel("off")
elif action == "default-allow":
    ufw.set_default_policy("allow")
elif action == "default-deny":
    ufw.set_default_policy("deny")
elif action == "status":
    ufw.get_status()
elif action == "enable":
    ufw.set_enabled(True)
elif action == "disable":
    ufw.set_enabled(False)
elif action == "allow" or action == "deny":
    ufw.set_rule(rule, type)

sys.exit(0)

