#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c)2003, Matthias A. Benkard.

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from core.common import _

from core import common

import logging
import re
from gtk import gdk
import gtk

# Module-level logger used by the parser for diagnostics.
log = logging.getLogger("core.parser")

# Characters parse() accepts inside an identifier or number without
# logging a warning.
ALPHANUMERIC = tuple(
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	"abcdefghijklmnopqrstuvwxyz"
	"0123456789"
	"%."
)

# Keys into PRIORITY for node kinds that are not infix operators.
FUNCTION = "func"
PAREN = "paren"
PRIMITIVE = "primitive"

# Priority of every operator / node kind. The drawing code compares these
# numbers to decide whether an operand has to be wrapped in parentheses.
PRIORITY = {
	# Relations.
	"=": -1,
	":=": -1,
	":": -1,
	# Additive operators.
	"+": 0,
	"-": 0,
	# Multiplicative operators.
	"*": 1,
	"/": 1,
	# Exponentiation (both spellings).
	"^": 2,
	"**": 2,
	# Self-contained nodes.
	FUNCTION: 100,
	PAREN: 100,
	PRIMITIVE: 100,
}


def parse(expr):
	"""Parse the expression string expr into a tree of drawable nodes.
	
	The string is scanned once, character by character. Text between an
	outermost pair of parentheses or square brackets is buffered and
	parsed by a recursive call; everything else is split into operands
	(objs) and infix operators (ops).
	
	Returns:
	  * a Primitive for a lone identifier or number,
	  * a Sum, Difference, Product, Quotient or Potence when exactly one
	    operator was found,
	  * the result of chain() when several operators were found,
	  * a List as soon as an outermost ']' is reached,
	  * an Assignment, Definition or Equation as soon as a top-level
	    ':', ':=' or '=' is reached (both sides are passed on as
	    unparsed strings).
	
	Raises:
	  Exception: on a comma directly inside parentheses that do not
	    belong to a function call.
	  AssertionError: when no operator and not exactly one operand was
	    found (for example for an empty string).
	"""
	operator_chars = tuple("+-*/^")
	
	paren_level = 0
	paren_children = []
	paren_content = ''
	bracket_level = 0
	bracket_children = []
	bracket_content = ''
	objs = []
	ops = []
	identifier = ''
	function = False
	bc = ''  # The previous significant character.
	count = -1  # Index of c within expr.
	
	for c in expr:
		count += 1
		
		# Inside parentheses or brackets every character is buffered.
		# Only the delimiters and commas fall through to the handlers
		# below, because they may change the nesting level or end an
		# argument.
		if paren_level > 0:
			paren_content += c
			if not (c == '(' or c == ')' or c == ','):
				continue
		elif bracket_level > 0:
			bracket_content += c
			if not (c == '[' or c == ']' or c == ','):
				continue
		
		if c in operator_chars:
			# An operator.
			if identifier == '':
				# Examples:
				#  +2
				#  -3
				#  (a)+b
				#  3 + (-3)  # this will be skipped anyway
				#  2^-3
				if len(objs) < 1 or bc in operator_chars and not (bc == '*' and c == '*'):
					# Just pretend the operator is a part
					# of the identifier.
					identifier += c
				else:
					if bc == '*' and c == '*':
						# Second asterisk of '**': upgrade the '*'
						# recorded for the first one.
						ops[-1] = '**'
					else:
						ops += [c]
			else:
				ops += [c]
				objs += [parse(identifier)]
				identifier = ''
		
		
		elif c == '(':
			# Opening paren.
			paren_level += 1
			if paren_level == 1:
				paren_content = ''
				if identifier != '':
					# Caught a function.
					function = True
		
		elif c == ')':
			paren_level -= 1
			if paren_level == 0:
				# Recursively parse paren contents. The buffered
				# strings still end in ',' or ')'; the recursive call
				# ignores such a trailing delimiter.
				a = [parse(child) for child in paren_children]
				a += [parse(paren_content)]
				# Forget the collected arguments so that they do not
				# leak into the next function call of this expression.
				paren_children = []
				if function:
					#log.debug(_("Caught function: '%s (%s)'") % (identifier, a))
					objs += [newfunction(identifier, a)]
					function = False
					identifier = ''
				else:
					#log.debug(_("Caught an isolated pair of parentheses."))
					objs += a
					identifier = ''
				paren_content = ''
		
		elif c == '[':
			bracket_level += 1
			if bracket_level == 1:
				bracket_content = ''
		
		elif c == ']':
			bracket_level -= 1
			if bracket_level == 0:
				# A list ends the parse; anything after ']' is not
				# looked at.
				a = List()
				a += [parse(child) for child in bracket_children]
				a += [parse(bracket_content)]
				return a
		
		elif c == ',':
			if bracket_level == 1:
				bracket_children += [bracket_content]
				bracket_content = ''
			elif paren_level == 1:
				if not function:
					raise Exception("Unexpected token: ,")
				paren_children += [paren_content]
				paren_content = ''
		
		
		elif c == ':':
			if expr[count + 1:count + 2] == '=':
				# ':=' is one token. Without this look-ahead a
				# definition whose left side ends in an identifier
				# ("a:=b") was split at the colon and the right side
				# started with a stray '='.
				return Definition(expr[:count], expr[count + 2:])
			if identifier:
				return Assignment(expr[:count], expr[count + 1:])
			else:
				bc = c
				continue
		
		
		elif c == '=':
			# Oh, an equation! This is starting to get interesting :)
			# We shall assume there can only be two parts of an
			# equation, the left hand one and the right hand one.
			if bc == ':':
				return Definition(expr[:count - 1], expr[count + 1:])
			else:
				return Equation(expr[:count], expr[count + 1:])
		
		elif c == ' ':
			# Blanks are skipped and do not update bc.
			continue
		
		else:
			# An identifier.
			if c not in ALPHANUMERIC:
				log.warn(_("'%c' is not an alphanumeric character.") % c)
			identifier += c
		
		bc = c
	
	if identifier != '':
		if len(objs) > 0:
			objs += [parse(identifier)]
		else:
			# The only identifier here. Hmm.
			return Primitive(identifier)
	
	if len(ops) == 0:
		assert len(objs) == 1
		return objs[0]
	
	if len(ops) == 1:
		# Easy deal.
		op = ops[0]
		if op == '+':
			return Sum(objs[0], objs[1])
		elif op == '-':
			return Difference(objs[0], objs[1])
		elif op == '*':
			return Product(objs[0], objs[1])
		elif op == '/':
			return Quotient(objs[0], objs[1])
		elif op == '^' or op == '**':
			return Potence(objs[0], objs[1])
	else:
		return chain(objs, ops)


def chain(objs, ops):
	"""Parse a chain of operations recursively.
	
	objs is the list of operands and ops the list of infix operators
	between them. Both lists are modified in place: each call folds the
	tightest-binding run of operators into a single node and recurses
	until no operator is left.
	
	Order of folding:
	  1. '^' (and its alias '**'), right-associative, into Potence;
	  2. the first run of '*' and '/' into Product, Quotient or
	     ProductChain;
	  3. the first run of '+' and '-' into Sum, Difference or SumChain.
	
	Returns the single remaining node.
	
	Raises:
	  ValueError: if ops contains an operator that is none of the above
	    (previously this case silently returned None).
	  AssertionError: if no operator is left but objs does not hold
	    exactly one element.
	"""
	
	# parse() records a doubled asterisk as '**'. It means the same as
	# '^', so fold both spellings together instead of ignoring '**'.
	for k in range(len(ops)):
		if ops[k] == '**':
			ops[k] = '^'
	
	if len(ops) == 0:
		assert(len(objs) == 1)
		return objs[0]
	
	if '^' in ops:
		# Exponentiation is right-associative: within the first run of
		# consecutive '^' operators, fold the rightmost one first.
		i = ops.index('^')
		while i + 1 < len(ops) and ops[i + 1] == '^':
			i += 1
		
		a = objs[i]
		b = objs[i + 1]
		ops[i:i+1] = []
		objs[i:i+2] = [Potence(a, b)]
		return chain(objs, ops)
	
	if '*' in ops or '/' in ops:
		_fold_first_run(objs, ops, '*', '/', Product, Quotient, ProductChain)
		return chain(objs, ops)
	
	if '+' in ops or '-' in ops:
		_fold_first_run(objs, ops, '+', '-', Sum, Difference, SumChain)
		return chain(objs, ops)
	
	raise ValueError("Unknown operator in chain: %r" % (ops,))


def _fold_first_run(objs, ops, op_a, op_b, node_a, node_b, chain_node):
	"""Replace the first run of op_a/op_b operators by a single node.
	
	A run of length one becomes node_a(left, right) or node_b(left,
	right), depending on the operator; a longer run becomes
	chain_node(operands, operators). objs and ops are edited in place.
	The caller guarantees that ops contains op_a or op_b.
	"""
	# Locate the run [start, end) of same-priority operators.
	start = 0
	while ops[start] != op_a and ops[start] != op_b:
		start += 1
	end = start
	while end < len(ops) and (ops[end] == op_a or ops[end] == op_b):
		end += 1
	
	run_ops = ops[start:end]
	# Index explicitly (rather than slice) so that a missing operand
	# raises IndexError instead of being dropped silently.
	run_objs = [objs[k] for k in range(start, end + 1)]
	
	if len(run_ops) == 1:
		# Easy.
		if run_ops[0] == op_a:
			node = node_a(run_objs[0], run_objs[1])
		else:
			node = node_b(run_objs[0], run_objs[1])
	else:
		# Uh-oh.
		node = chain_node(run_objs, run_ops)
	
	ops[start:end] = []
	objs[start:end + 1] = [node]


def newfunction(name, args):
	"""Build the node for a call of the function name with arguments args.
	
	'sqrt' is special-cased: it is returned as its first argument raised
	to the power 1/2 (a Potence). Every other name yields a plain
	Function node.
	"""
	
	if name != 'sqrt':
		return Function(name, args)
	
	# sqrt(x) is represented as x^(1/2).
	radicand = args[0]
	one_half = parse("1/2")
	return Potence(radicand, one_half)


# Implementation inheritance is a Good Thing.
class Drawable:
	"""Base class of all expression nodes; stores a relative position."""
	
	def set_relative_position(self, x, y):
		"""Remember (x, y) as this object's relative position."""
		self.x = x
		self.y = y
	
	def get_relative_position(self):
		"""Return the position stored by set_relative_position() as (x, y).
		
		Raises AttributeError if no position has been set yet.
		"""
		# The coordinates live on the instance; bare x and y would be
		# undefined names here.
		return self.x, self.y


def draw_operation_chain(self, widget, style):
	"""Render an operation chain into self.pixmap.
	
	self must provide operands (p1, p2, p3, ...), ops (op1, op2, ...
	where op_i stands between p_i and p_i+1) and state. widget supplies
	the window the pixmaps are created for and style the font and the
	graphics contexts.
	
	If the chain contains a '/', it is drawn as a fraction; otherwise
	the operands are drawn left to right with the operators between
	them. Nothing is returned; the result is stored in self.pixmap.
	"""
	# p1, op1, p2, op2, p3, op3, p4...
	if '/' in self.ops:
		# We must draw a quotient. Uh... this may become hairy.
		_draw_fraction(self, widget, style)
	else:
		# A "normal" operation chain.
		_draw_inline_chain(self, widget, style)


def _draw_product_group(operands, state, widget, style):
	"""Return a pixmap showing operands multiplied together."""
	if len(operands) > 1:
		# Borrow a bare Drawable as a temporary '*' chain.
		tmp = Drawable()
		tmp.operands = operands
		tmp.ops = ['*'] * (len(operands) - 1)
		tmp.state = state
		draw_operation_chain(tmp, widget, style)
		return tmp.pixmap
	return operands[0].draw(widget, style)


def _draw_fraction(self, widget, style):
	"""Draw a '*'/'/' chain as numerator over denominator."""
	# Separate factors with exponents >0 (numerator) from those with
	# exponents <0 (denominator).
	factors = [self.operands[0]]
	divisors = []
	for i in range(len(self.ops)):
		if self.ops[i] == '/':
			divisors += [self.operands[i + 1]]
		else:
			factors += [self.operands[i + 1]]
	
	factorpixmap = _draw_product_group(factors, self.state, widget, style)
	divisorpixmap = _draw_product_group(divisors, self.state, widget, style)
	
	# Phew. So far so good.
	# Now the only thing left is drawing the beast.
	w1, h1 = factorpixmap.get_size()
	w2, h2 = divisorpixmap.get_size()
	
	# The division line is 8 pixels wider than the wider part; 5 extra
	# rows separate numerator, line and denominator.
	w = l = max(w1, w2) + 8
	h = h1 + h2 + 5
	
	y1 = 0
	y  = h1 + 2
	y2 = h - h2
	
	# Centre both parts horizontally.
	x1 = (w + w1) // 2 - w1
	x  = 0
	x2 = (w + w2) // 2 - w2
	
	self.pixmap = gdk.Pixmap(widget.window, w, h)
	self.pixmap.draw_rectangle(
			style.bg_gc[self.state], True,
			0, 0,
			w, h)
	
	# Draw the division line.
	self.pixmap.draw_line(
			style.fg_gc[self.state],
			x, y,
			x + l, y)
	
	# Draw the operands.
	self.pixmap.draw_drawable(
			style.fg_gc[self.state], factorpixmap,
			0, 0,
			x1, y1,
			w1, h1)
	self.pixmap.draw_drawable(
			style.fg_gc[self.state], divisorpixmap,
			0, 0,
			x2, y2,
			w2, h2)


def _draw_chain_operand(operand, level, previous_op, widget, style):
	"""Return the pixmap of one operand, parenthesised if necessary.
	
	An operand needs parentheses when it binds more weakly than the
	chain (priority below level), or when it follows a '-' and binds
	exactly as strongly: a - (b + c) must not be shown as a - b + c.
	"""
	if operand.priority < level or (
			previous_op == '-' and operand.priority == level):
		operand = Paren(operand)
	return operand.draw(widget, style)


def _draw_inline_chain(self, widget, style):
	"""Draw operands and operators side by side, vertically centred."""
	font = gdk.font_from_description(style.font_desc)
	# Use the node's own state throughout, as the fraction branch and
	# all node classes do.
	fg = style.fg_gc[self.state]
	bg = style.bg_gc[self.state]
	
	# We assume that a chain always consists of operations
	# that have the _same_ priority. Thus it's irrelevant
	# which element of self.ops we test the operands against.
	level = PRIORITY[self.ops[0]]
	
	tokens = []
	previous_op = None
	for i in range(len(self.ops)):
		tokens += [_draw_chain_operand(
				self.operands[i], level, previous_op, widget, style)]
		
		# The operator itself, padded with one blank on each side.
		op = self.ops[i]
		label = " %s " % op
		w = font.string_width(label)
		h = font.ascent + font.descent
		pixmap = gdk.Pixmap(widget.window, w, h)
		pixmap.draw_rectangle(
				bg, True,
				0, 0,
				w, h)
		pixmap.draw_text(
				font, fg,
				0, font.ascent,
				label)
		tokens += [pixmap]
		previous_op = op
	
	# The operand after the last operator.
	tokens += [_draw_chain_operand(
			self.operands[-1], level, previous_op, widget, style)]
	
	# Determine chain dimensions.
	w = 0
	h = 0
	for token in tokens:
		pw, ph = token.get_size()
		w += pw
		h = max(h, ph)
	
	# Create pixmap.
	self.pixmap = gdk.Pixmap(widget.window, w, h)
	self.pixmap.draw_rectangle(
			bg, True,
			0, 0,
			w, h)
	
	# Draw the chain.
	xpos = 0
	for token in tokens:
		pw, ph = token.get_size()
		ypos = (h + ph) // 2 - ph
		self.pixmap.draw_drawable(
				fg, token,
				0, 0,
				xpos, ypos,
				pw, ph)
		xpos += pw
	
	
	
class Sum (Drawable):
	"""Node for the addition of exactly two operands."""
	
	def __init__(self, summand1, summand2):
		"""Store both summands, in the order given, with a single '+'."""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["+"]
		self.ops = ["+"]
		self.operands = (summand1, summand2)
	
	def draw(self, widget, style):
		"""Render the sum and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class SumChain (Drawable):
	"""Node for a run of more than one '+' / '-' operation."""
	
	def __init__(self, operands, ops):
		"""Store operands and ops (lists or tuples) as tuples.
		
		ops[i] is the operator between operands[i] and operands[i + 1].
		"""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["+"]
		self.ops = tuple(ops)
		self.operands = tuple(operands)
	
	def draw(self, widget, style):
		"""Render the chain and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Difference (Drawable):
	"""Node for the subtraction of exactly two operands."""
	
	def __init__(self, subtrahend, minuend):
		"""Store the operands so that the node reads 'subtrahend - minuend'.
		
		NOTE(review): the first parameter ends up on the left of the
		minus sign, i.e. the parameter names are the reverse of the usual
		mathematical terms. They are kept because they are part of the
		signature.
		"""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["-"]
		self.ops = ["-"]
		self.operands = (subtrahend, minuend)
	
	def draw(self, widget, style):
		"""Render the difference and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Product (Drawable):
	"""Node for the multiplication of exactly two operands."""
	
	def __init__(self, factor1, factor2):
		"""Store both factors, in the order given, with a single '*'."""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["*"]
		self.ops = ["*"]
		self.operands = (factor1, factor2)
	
	def draw(self, widget, style):
		"""Render the product and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class ProductChain (Drawable):
	"""Node for a run of more than one '*' / '/' operation."""
	
	def __init__(self, operands, ops):
		"""Store operands and ops (lists or tuples) as tuples.
		
		ops[i] is the operator between operands[i] and operands[i + 1].
		"""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["*"]
		self.ops = tuple(ops)
		self.operands = tuple(operands)
	
	def draw(self, widget, style):
		"""Render the chain and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Quotient (Drawable):
	"""Node for the division of exactly two operands."""
	
	def __init__(self, dividend, divisor):
		"""Store dividend and divisor with a single '/' between them."""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY["/"]
		self.ops = ["/"]
		self.operands = (dividend, divisor)
	
	def draw(self, widget, style):
		"""Render the quotient and return the resulting pixmap."""
		# The shared chain renderer draws chains containing '/' as a
		# fraction and leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Potence (Drawable):
	"""Node for a power: a base with a raised exponent."""
	
	def __init__(self, base, exponent):
		"""Store base and exponent as the two operands."""
		self.operands = (base, exponent)
		self.priority = PRIORITY["^"]
		self.state = gtk.STATE_NORMAL
	
	def draw(self, widget, style):
		"""Render the power and return the resulting pixmap.
		
		The base is drawn bottom-aligned on the left and the exponent
		top-aligned to its right. The result is also stored in
		self.pixmap.
		"""
		base, exponent = self.operands
		
		# Check priority and add parentheses if necessary. The comparison
		# is strict: a base that is itself a power has to be
		# parenthesised, otherwise (a^b)^c shows both exponents side by
		# side and cannot be told apart from a^(bc).
		if base.priority > self.priority:
			op1 = base
		else:
			op1 = Paren(base)
		pmb = op1.draw(widget, style)
		wb, hb = pmb.get_size()
		
		# Check complexity of exponent and add parentheses
		# if sensible: only self-contained nodes stay bare.
		if exponent.priority >= PRIORITY[PRIMITIVE]:
			op2 = exponent
		else:
			op2 = Paren(exponent)
		pme = op2.draw(widget, style)
		we, he = pme.get_size()
		
		# The exponent overlaps the upper third of the base; the pixmap
		# only grows by whatever part of the exponent is taller than the
		# remaining two thirds of the base.
		h = hb + max(0, he - hb*2//3)
		w = wb + we
		
		self.pixmap = gdk.Pixmap(widget.window, w, h)
		self.pixmap.draw_rectangle(
				style.bg_gc[self.state], True,
				0, 0,
				w, h)
		
		# Base: bottom left.
		self.pixmap.draw_drawable(
				style.fg_gc[self.state], pmb,
				0, 0,
				0, h - hb,
				wb, hb)
		# Exponent: top right.
		self.pixmap.draw_drawable(
				style.fg_gc[self.state], pme,
				0, 0,
				wb, 0,
				we, he)
		
		return self.pixmap


class Equation (Drawable):
	"""Node for 'term1 = term2'."""
	
	def __init__(self, term1, term2):
		"""Keep both sides as given (terms) and parsed (operands).
		
		term1 and term2 are unparsed expression strings; each is run
		through parse().
		"""
		self.terms = (term1, term2)
		self.operands = [parse(term) for term in self.terms]
		self.state = gtk.STATE_NORMAL
		self.ops = ["="]
		self.priority = PRIORITY["="]
	
	def draw(self, widget, style):
		"""Render the equation and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Definition (Drawable):
	"""Node for 'term1 := term2'."""
	
	def __init__(self, term1, term2):
		"""Keep both sides as given (terms) and parsed (operands).
		
		term1 and term2 are unparsed expression strings; each is run
		through parse().
		"""
		self.terms = (term1, term2)
		self.operands = [parse(term) for term in self.terms]
		self.state = gtk.STATE_NORMAL
		self.ops = [":="]
		self.priority = PRIORITY[":="]
	
	def draw(self, widget, style):
		"""Render the definition and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Assignment (Drawable):
	"""Node for 'term1 : term2'."""
	
	def __init__(self, term1, term2):
		"""Keep both sides as given (terms) and parsed (operands).
		
		term1 and term2 are unparsed expression strings; each is run
		through parse().
		"""
		self.terms = (term1, term2)
		self.operands = [parse(term) for term in self.terms]
		self.state = gtk.STATE_NORMAL
		self.ops = [":"]
		self.priority = PRIORITY[":"]
	
	def draw(self, widget, style):
		"""Render the assignment and return the resulting pixmap."""
		# The shared chain renderer leaves its result in self.pixmap.
		draw_operation_chain(self, widget, style)
		return self.pixmap


class Function (Drawable):
	"""Node for a function call: a name followed by parenthesised arguments."""
	
	def __init__(self, name, args):
		"""Store the arguments as a tuple and wrap name in a Primitive."""
		self.args = tuple(args)
		self.name = Primitive(name)
		self.priority = PRIORITY[FUNCTION]
		self.state = gtk.STATE_NORMAL
	
	# TODO
	def draw(self, widget, style):
		"""Render 'name(arg1, arg2, ...)' and return the resulting pixmap.
		
		The result is also stored in self.pixmap.
		"""
		# self.name + Paren, but with commas (see List).
		
		font = gdk.font_from_description(style.font_desc)
		fg = style.fg_gc[self.state]
		
		sw = font.string_width(", ")
		sh = font.ascent + font.descent
		
		cpm = []
		w = 0
		h = sh
		# Determine dimensions.
		for x in self.args:
			pm = x.draw(widget, style)
			cpm += [pm]
			pw, ph = pm.get_size()
			w += pw
			h = max(h, ph)
		
		# self.name
		namepm = self.name.draw(widget, style)
		nw, nh = namepm.get_size()
		w += nw
		h = max(h, nh)
		
		h += 2
		pw = h // 6  # Width of one paren.
		w += 2*pw + 3
		# One ", " between each pair of arguments; never a negative
		# amount, even without arguments.
		w += sw * max(0, len(self.args) - 1)
		
		self.pixmap = gdk.Pixmap(widget.window, w, h)
		self.pixmap.draw_rectangle(
				style.bg_gc[self.state], True,
				0, 0,
				w, h)
		
		# Draw the name, vertically centred.
		self.pixmap.draw_drawable(
				fg, namepm,
				0, 0,
				0, (h + nh)//2 - nh,
				nw, nh)
				
		xpos = nw + 1
		
		# Rows at which the slanted ends of the parens meet the straight
		# middle part.
		top = h // 6
		bottom = h - 1 - h // 6
		
		# Left paren. The outline runs back along itself so that the
		# unfilled polygon shows as an open stroke.
		self.pixmap.draw_polygon(
				fg, False,
				((xpos + pw, 0), (xpos, top), (xpos, bottom), (xpos + pw, h-1),
				 (xpos, bottom), (xpos, top)))
		
		# Right paren.
		self.pixmap.draw_polygon(
				fg, False,
				((w-1 - pw, 0), (w-1, top), (w-1, bottom), (w-1 - pw, h-1),
				 (w-1, bottom), (w-1, top)))
		
		# Content, comma-separated.
		xpos += pw + 1
		last = len(cpm) - 1
		for n, pm in enumerate(cpm):
			pw, ph = pm.get_size()
			ypos = h - ph
			self.pixmap.draw_drawable(
					fg, pm,
					0, 0,
					xpos, ypos,
					pw, ph)
			
			xpos += pw
			
			# Decide by position, not by searching cpm for pm: a search
			# is quadratic and finds the wrong slot if the same pixmap
			# object occurs twice.
			if n < last:
				self.pixmap.draw_text(
						font, fg,
						xpos, h - font.descent, ", ")
				xpos += sw
		
		return self.pixmap


class Paren (Drawable):
	"""Node that wraps another node in a pair of parentheses."""
	
	def __init__(self, expr):
		"""Remember expr, the node to be enclosed."""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY[PAREN]
		self.expr = expr
	
	def draw(self, widget, style):
		"""Draws a pair of parens with something inside."""
		inner = self.expr.draw(widget, style)
		ow, oh = inner.get_size()
		
		# Two extra rows of height; each paren is a sixth of the height
		# wide.
		h = oh + 2
		pw = h // 6
		w = ow + 2*pw
		
		fg = style.fg_gc[self.state]
		
		self.pixmap = gdk.Pixmap(widget.window, w, h)
		self.pixmap.draw_rectangle(
				style.bg_gc[self.state], True,
				0, 0,
				w, h)
		
		# Content, centred.
		self.pixmap.draw_drawable(
				fg, inner,
				0, 0,
				w // 2 - ow // 2, h // 2 - oh // 2,
				ow, oh)
		
		# Rows at which the slanted ends of the parens meet the straight
		# middle part.
		top = h // 6
		bottom = h - 1 - h // 6
		right = w - 1
		
		# Left paren.
		self.pixmap.draw_polygon(
				fg, False,
				((pw, 0), (0, top), (0, bottom), (pw, h-1),
				 (0, bottom), (0, top)))
		
		# Right paren.
		self.pixmap.draw_polygon(
				fg, False,
				((right - pw, 0), (right, top), (right, bottom), (right - pw, h-1),
				 (right, bottom), (right, top)))
		
		return self.pixmap


class List (Drawable, list):
	"""A list of nodes, drawn comma-separated inside square brackets."""
	
	def __init__(self, arg=()):
		"""Initialise the list from the iterable arg (empty by default)."""
		# An immutable default avoids sharing one list object between
		# calls.
		list.__init__(self, arg)
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY[PAREN]
	
	def draw(self, widget, style):
		"""Render '[item1, item2, ...]' and return the resulting pixmap.
		
		The result is also stored in self.pixmap.
		"""
		font = gdk.font_from_description(style.font_desc)
		fg = style.fg_gc[self.state]
		
		sw = font.string_width(", ")
		sh = font.ascent + font.descent
		
		cpm = []
		w = 0
		h = sh
		# Determine dimensions.
		for x in self:
			pm = x.draw(widget, style)
			cpm += [pm]
			pw, ph = pm.get_size()
			w += pw
			h = max(h, ph)
		
		pw = h // 6  # Width of one bracket.
		# One ", " between each pair of items; never a negative amount,
		# even for an empty list.
		w += sw * max(0, len(self) - 1) + pw*2 + 2
		h += 4
		
		self.pixmap = gdk.Pixmap(widget.window, w, h)
		self.pixmap.draw_rectangle(
				style.bg_gc[self.state], True,
				0, 0,
				w, h)
		
		# Left square bracket. The outline runs back along itself so
		# that the unfilled polygon shows as an open stroke.
		self.pixmap.draw_polygon(
				fg, False,
				((pw, 1), (0, 1), (0, h-2), (pw, h-2),
				 (0, h-2), (0, 1)))
		
		# Right square bracket.
		self.pixmap.draw_polygon(
				fg, False,
				((w-1 - pw, 1), (w-1, 1), (w-1, h-2), (w-1 - pw, h-2),
				 (w-1, h-2), (w-1, 1)))
		
		# Content, comma-separated.
		xpos = pw + 1
		last = len(cpm) - 1
		for n, pm in enumerate(cpm):
			pw, ph = pm.get_size()
			ypos = h - ph
			self.pixmap.draw_drawable(
					fg, pm,
					0, 0,
					xpos, ypos,
					pw, ph)
			
			xpos += pw
			
			# Decide by position, not by searching cpm for pm: a search
			# is quadratic and finds the wrong slot if the same pixmap
			# object occurs twice.
			if n < last:
				self.pixmap.draw_text(
						font, fg,
						xpos, h - font.descent, ", ")
				xpos += sw
		
		return self.pixmap


class Primitive (Drawable):
	"""Leaf node: a piece of text (identifier or number) drawn as is."""
	
	def __init__(self, name):
		"""Remember name, the text to be drawn."""
		self.state = gtk.STATE_NORMAL
		self.priority = PRIORITY[PRIMITIVE]
		self.name = name
	
	def draw(self, widget, style):
		"""Render the text and return the resulting pixmap.
		
		The pixmap is exactly as wide as the text and one font line
		(ascent plus descent) high. It is also stored in self.pixmap.
		"""
		font = gdk.font_from_description(style.font_desc)
		
		width = font.string_width(self.name)
		height = font.ascent + font.descent
		
		canvas = gdk.Pixmap(widget.window, width, height)
		self.pixmap = canvas
		
		# Background first, then the text on its baseline.
		canvas.draw_rectangle(
				style.bg_gc[self.state], True,
				0, 0,
				width, height)
		canvas.draw_text(
				font, style.fg_gc[self.state],
				0, font.ascent,
				self.name)
		
		return canvas

