from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

### Types ###

class TypeExpression(ABC):
	"""Abstract base for every AST node that denotes a type.

	Concrete subclasses must provide ``represent``.
	"""

	@abstractmethod
	def represent(self) -> str:
		"""Return a textual rendering of this type expression."""

@dataclass
class TupleTypeExpr(TypeExpression):
	"""Tuple type; ``types`` presumably lists the element types in order."""

	types: List[TypeExpression]

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

@dataclass
class ListTypeExpr(TypeExpression):
	"""List type; ``type`` is presumably the element type."""

	type: TypeExpression

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

@dataclass
class ArrayTypeExpr(TypeExpression):
	"""Array type of ``type`` elements.

	``number`` is presumably the fixed element count (TODO confirm).
	"""

	type: TypeExpression
	number: int

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

@dataclass
class TypeName(TypeExpression):
	"""A type referenced by its identifier ``name``."""

	name: str

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

@dataclass
class TypeSpecification(TypeExpression):
	"""A base type applied to type arguments.

	This is presumably generic instantiation: ``type`` is the base type and
	``types`` are its arguments (TODO confirm).
	"""

	type: TypeExpression
	types: List[TypeExpression]

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

@dataclass
class FunctionTypeExpr(TypeExpression):
	"""Function type built from argument types and a return type."""

	arguments: List[TypeExpression]
	return_type: TypeExpression

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

### Statements ###

class Statement:
	"""Marker base class shared by every statement node in the AST."""

@dataclass
class Statements(Statement):
	"""A sequence of statements grouped into one statement node."""

	# Child statements, presumably kept in source order.
	statements: List[Statement]

### Enums + Struct ###

@dataclass
class EnumEntry:
	"""One variant of an enum definition."""

	# Variant identifier.
	name: str
	# Types carried by this variant (may be empty).
	types: List[TypeExpression]

@dataclass
class EnumDefinition(Statement):
	"""Statement that declares an enum named ``name``."""

	name: str
	# The enum's variants.
	entries: List[EnumEntry]

@dataclass
class TypeDeclaration:
	"""A ``name`` paired with its declared ``type``.

	Used for struct fields, function parameters, and lambda parameters.
	"""

	name: str
	type: TypeExpression

@dataclass
class StructDefinition(Statement):
	"""Statement that declares a struct named ``name``."""

	name: str
	# The struct's fields as name/type pairs.
	entries: List[TypeDeclaration]

### Function ###

@dataclass
class FunctionDefinition(Statement):
	"""Statement that defines a named function."""

	name: str
	# Parameters as name/type pairs. ``List`` matches the annotation style of
	# the rest of the module.
	arguments: List[TypeDeclaration]
	# None when no return type was written.
	return_type: Optional[TypeExpression]
	body: Statement

### Expressions ###

class Expression(ABC):
	"""Abstract base for expression nodes that can render themselves as text.

	Each node reports a binding strength via ``precedence``; higher numbers
	bind tighter. ``wrap`` uses it to decide where parentheses are needed.
	"""

	@abstractmethod
	def precedence(self) -> int:
		"""Return this node's binding strength (higher binds tighter)."""

	@abstractmethod
	def represent(self) -> str:
		"""Return the textual rendering of this expression."""

	def wrap(self, other: 'Expression') -> str:
		"""Render child ``other``, parenthesized if it binds looser than ``self``.

		Children with equal or higher precedence are returned unparenthesized.
		"""
		needs_parens = other.precedence() < self.precedence()
		if not needs_parens:
			return other.represent()
		return '(' + other.represent() + ')'

@dataclass
class FunctionCall(Expression):
	"""Call of ``function`` with positional ``arguments``, rendered ``f(a, b)``."""

	function: Expression
	arguments: List[Expression]

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		callee = self.wrap(self.function)
		# Arguments are comma-separated, so none of them needs wrapping.
		args = ', '.join(arg.represent() for arg in self.arguments)
		return callee + "(" + args + ")"

@dataclass
class Variable(Expression):
	"""Reference to a variable by its identifier; renders as the bare name."""

	name: str

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		identifier = self.name
		return identifier

@dataclass
class ArrayAccess(Expression):
	"""Indexing expression rendered as ``array[index]``."""

	array: Expression
	index: Expression

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		target = self.wrap(self.array)
		# The index sits inside brackets, so it is never wrapped.
		subscript = self.index.represent()
		return target + "[" + subscript + "]"

@dataclass
class Array(Expression):
	"""Array literal rendered as ``[e1, e2, ...]``."""

	element_type: TypeExpression
	array: List[Expression]

	def represent(self) -> str:
		"""Render the literal from each element's ``represent()``.

		Previously this used ``map(str, ...)``, which produced the dataclass
		repr (e.g. ``Number(number=1)``) instead of the expression text.
		"""
		return "["+', '.join(element.represent() for element in self.array)+"]"

	def precedence(self) -> int: return 13

@dataclass
class FieldAccess(Expression):
	"""Member access rendered as ``expression.field``."""

	expression: Expression
	field: str

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		owner = self.wrap(self.expression)
		return owner + "." + self.field

@dataclass
class Number(Expression):
	"""Integer literal, rendered via ``str()``."""

	number: int

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		literal = str(self.number)
		return literal

@dataclass
class String(Expression):
	"""String literal, rendered with Python's ``repr`` quoting and escaping."""

	string: str

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		quoted = repr(self.string)
		return quoted

@dataclass
class TupleExpr(Expression):
	"""Tuple expression rendered as ``([e1, e2, ...])``."""

	elements: List[Expression]

	def precedence(self) -> int:
		return 13

	def represent(self) -> str:
		inner = ', '.join(element.represent() for element in self.elements)
		return "([" + inner + "])"

@dataclass
class StructInstantiation(Expression):
	"""Construction of a struct value.

	``arguments`` holds (name, value) pairs; the name is presumably the
	field being initialised (TODO confirm).
	"""

	struct: Expression
	arguments: List[Tuple[str, Expression]]

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

	def precedence(self) -> int: return 13

@dataclass
class LoopComprehension(Expression):
	"""Comprehension producing ``body`` for each ``variable`` drawn from ``array``."""

	element_type: TypeExpression
	body: Expression
	variable: str # TODO: Pattern matching
	array: Expression

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

	def precedence(self) -> int: return 13

@dataclass
class Lambda(Expression):
	"""Anonymous function over typed ``parameters`` returning ``expression``.

	Its precedence is 0 (the lowest), so ``wrap`` parenthesizes it inside
	any operator with precedence above 0.
	"""

	parameters: List[TypeDeclaration]
	expression: Expression

	def represent(self) -> str:
		"""Not implemented yet.

		Raises:
			NotImplementedError: always. This replaces ``assert False``, which
				is stripped under ``python -O`` and would silently return None.
		"""
		raise NotImplementedError("Unimplemented")

	def precedence(self) -> int: return 0

@dataclass
class Ternary(Expression):
	"""Conditional expression rendered as ``if_true if condition else if_false``."""

	condition: Expression
	if_true: Expression
	if_false: Expression

	def represent(self) -> str:
		"""Render the conditional, parenthesizing operands where needed.

		A conditional nests without parentheses only in the ``else`` slot,
		because the form groups to the right. In the ``if_true`` and
		``condition`` slots a nested conditional must be parenthesized.
		Otherwise ``Ternary(c, Ternary(d, a, b), e)`` would print as
		``a if d else b if c else e``, which reads back as
		``a if d else (b if c else e)``.
		"""
		def tight(child: Expression) -> str:
			# Parenthesize at equal precedence too (not just lower).
			text = child.represent()
			if child.precedence() <= self.precedence():
				text = '('+text+')'
			return text

		return tight(self.if_true)+" if "+tight(self.condition)+" else "+self.wrap(self.if_false)

	def precedence(self) -> int: return 1

@dataclass
class Or(Expression):
	"""Logical disjunction ``lhs or rhs`` (precedence 2)."""

	lhs: Expression
	rhs: Expression

	def precedence(self) -> int:
		return 2

	def represent(self) -> str:
		left, right = self.wrap(self.lhs), self.wrap(self.rhs)
		return left + " or " + right

@dataclass
class And(Expression):
	"""Logical conjunction ``lhs and rhs`` (precedence 3)."""

	lhs: Expression
	rhs: Expression

	def precedence(self) -> int:
		return 3

	def represent(self) -> str:
		left, right = self.wrap(self.lhs), self.wrap(self.rhs)
		return left + " and " + right

@dataclass
class Bor(Expression):
	"""Bitwise OR ``lhs | rhs`` (precedence 4)."""

	lhs: Expression
	rhs: Expression

	def precedence(self) -> int:
		return 4

	def represent(self) -> str:
		left, right = self.wrap(self.lhs), self.wrap(self.rhs)
		return left + " | " + right

@dataclass
class Bxor(Expression):
	"""Bitwise XOR ``lhs ^ rhs`` (precedence 5)."""

	lhs: Expression
	rhs: Expression

	def precedence(self) -> int:
		return 5

	def represent(self) -> str:
		left, right = self.wrap(self.lhs), self.wrap(self.rhs)
		return left + " ^ " + right

@dataclass
class Band(Expression):
	"""Bitwise AND ``lhs & rhs`` (precedence 6)."""

	lhs: Expression
	rhs: Expression

	def precedence(self) -> int:
		return 6

	def represent(self) -> str:
		left, right = self.wrap(self.lhs), self.wrap(self.rhs)
		return left + " & " + right

@dataclass
class Equal(Expression):
	"""Equality comparison ``lhs == rhs`` (precedence 7)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs == rhs``.

		Equal-precedence operators group left-to-right, so the right operand
		is also parenthesized at equal precedence. Otherwise
		``Equal(a, NotEqual(b, c))`` would print as ``a == b != c``, which
		regroups as ``(a == b) != c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" == "+rhs

	def precedence(self) -> int: return 7

@dataclass
class NotEqual(Expression):
	"""Inequality comparison ``lhs != rhs`` (precedence 7)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs != rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``NotEqual(a, Equal(b, c))`` would print as ``a != b == c``, which
		regroups as ``(a != b) == c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" != "+rhs

	def precedence(self) -> int: return 7

@dataclass
class LessThan(Expression):
	"""Comparison ``lhs < rhs`` (precedence 8)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs < rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``LessThan(a, GreaterThan(b, c))`` would print as ``a < b > c``,
		which regroups as ``(a < b) > c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" < "+rhs

	def precedence(self) -> int: return 8

@dataclass
class GreaterThan(Expression):
	"""Comparison ``lhs > rhs`` (precedence 8)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs > rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``GreaterThan(a, LessThan(b, c))`` would print as ``a > b < c``,
		which regroups as ``(a > b) < c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" > "+rhs

	def precedence(self) -> int: return 8

@dataclass
class LessThanOrEqual(Expression):
	"""Comparison ``lhs <= rhs`` (precedence 8)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs <= rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``LessThanOrEqual(a, LessThan(b, c))`` would print as ``a <= b < c``,
		which regroups as ``(a <= b) < c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" <= "+rhs

	def precedence(self) -> int: return 8

@dataclass
class GreaterThanOrEqual(Expression):
	"""Comparison ``lhs >= rhs`` (precedence 8)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs >= rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``GreaterThanOrEqual(a, GreaterThan(b, c))`` would print as
		``a >= b > c``, which regroups as ``(a >= b) > c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" >= "+rhs

	def precedence(self) -> int: return 8

@dataclass
class ShiftLeft(Expression):
	"""Left shift ``lhs << rhs`` (precedence 9)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs << rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``ShiftLeft(a, ShiftLeft(b, c))`` would print as ``a << b << c``,
		which regroups as ``(a << b) << c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" << "+rhs

	def precedence(self) -> int: return 9

@dataclass
class ShiftRight(Expression):
	"""Right shift ``lhs >> rhs`` (precedence 9)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs >> rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``ShiftRight(a, ShiftLeft(b, c))`` would print as ``a >> b << c``,
		which regroups as ``(a >> b) << c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" >> "+rhs

	def precedence(self) -> int: return 9


@dataclass
class Addition(Expression):
	"""Addition ``lhs + rhs`` (precedence 10)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs + rhs``.

		The right operand is also parenthesized at equal precedence, so the
		tree's grouping survives. Otherwise ``Addition(a, Subtract(b, c))``
		would print as ``a + b - c``, i.e. ``(a + b) - c``. That is not
		equivalent in general (e.g. floating point or overflow).
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" + "+rhs

	def precedence(self) -> int: return 10

@dataclass
class Subtract(Expression):
	"""Subtraction ``lhs - rhs`` (precedence 10)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs - rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``Subtract(a, Subtract(b, c))`` would print as ``a - b - c``, which
		means ``(a - b) - c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" - "+rhs

	def precedence(self) -> int: return 10

@dataclass
class Multiplication(Expression):
	"""Multiplication ``lhs * rhs`` (precedence 11)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs * rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``Multiplication(a, Division(b, c))`` would print as ``a * b / c``,
		i.e. ``(a * b) / c``. That differs under integer division.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" * "+rhs

	def precedence(self) -> int: return 11

@dataclass
class Division(Expression):
	"""Division ``lhs / rhs`` (precedence 11)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs / rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``Division(a, Division(b, c))`` would print as ``a / b / c``, which
		means ``(a / b) / c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" / "+rhs

	def precedence(self) -> int: return 11

@dataclass
class Modulo(Expression):
	"""Remainder ``lhs % rhs`` (precedence 11)."""

	lhs: Expression
	rhs: Expression

	def represent(self) -> str:
		"""Render as ``lhs % rhs``.

		The right operand is also parenthesized at equal precedence. Otherwise
		``Modulo(a, Multiplication(b, c))`` would print as ``a % b * c``,
		which means ``(a % b) * c``.
		"""
		lhs = self.wrap(self.lhs)
		rhs = self.rhs.represent()
		if self.rhs.precedence() <= self.precedence():
			rhs = '('+rhs+')'
		return lhs+" % "+rhs

	def precedence(self) -> int: return 11

@dataclass
class Bnot(Expression):
	"""Bitwise complement ``~expression`` (precedence 12)."""

	expression: Expression

	def precedence(self) -> int:
		return 12

	def represent(self) -> str:
		operand = self.wrap(self.expression)
		return "~" + operand

@dataclass
class Not(Expression):
	"""Logical negation rendered as ``!expression`` (precedence 12)."""

	expression: Expression

	def precedence(self) -> int:
		return 12

	def represent(self) -> str:
		operand = self.wrap(self.expression)
		return "!" + operand

@dataclass
class UnaryPlus(Expression):
	"""Unary plus ``+expression`` (precedence 12)."""

	expression: Expression

	def precedence(self) -> int:
		return 12

	def represent(self) -> str:
		operand = self.wrap(self.expression)
		return "+" + operand

@dataclass
class UnaryMinus(Expression):
	"""Arithmetic negation ``-expression`` (precedence 12)."""

	expression: Expression

	def precedence(self) -> int:
		return 12

	def represent(self) -> str:
		operand = self.wrap(self.expression)
		return "-" + operand

@dataclass
class ExpressionStatement(Statement):
	"""An expression used in statement position."""

	expression: Expression

### Assignment + Declaration ###

@dataclass
class Assignment(Statement):
	"""Assignment of ``rhs`` to the target ``lhs``."""

	# Assignment target.
	lhs: Expression
	# Value being assigned.
	rhs: Expression
	# Optional type annotation; None when omitted.
	type: Optional[TypeExpression] = None

@dataclass
class TypeDeclarationStatement(Statement):
	"""Statement that declares a name with a type (no value given)."""

	type_declaration: TypeDeclaration

### Control flow ###

@dataclass
class IfStatement(Statement):
	"""Conditional statement with an optional else branch."""

	condition: Expression
	body: Statement
	# None when there is no else branch.
	else_body: Optional[Statement]

@dataclass
class WhileStatement(Statement):
	"""Pre-tested loop: ``body`` repeats while ``condition`` holds."""

	condition: Expression
	body: Statement

@dataclass
class DoWhileStatement(Statement):
	"""Post-tested loop: ``body`` runs first, then ``condition`` is checked."""

	body: Statement
	# May be None; the meaning of a missing condition is decided elsewhere
	# (TODO confirm, e.g. an unconditional loop).
	condition: Optional[Expression]

@dataclass
class BreakStatement(Statement):
	"""``break`` out of the innermost loop; carries no data."""

@dataclass
class ContinueStatement(Statement):
	"""``continue`` to the next loop iteration; carries no data."""

@dataclass
class ReturnStatement(Statement):
	"""Return ``expression`` from the enclosing function."""

	expression: Expression

@dataclass
class MatchStatement(Statement):
	"""Match ``value`` against a list of cases."""

	value: Expression
	# (pattern, body) pairs; patterns are represented as expressions.
	cases: List[Tuple[Expression, Statement]]

@dataclass
class AssertStatement(Statement):
	"""Assertion of ``condition`` with an optional ``message`` expression."""

	condition: Expression
	# None when no message was given.
	message: Optional[Expression]

@dataclass
class ForLoop(Statement):
	"""Loop binding ``variable`` to each element of ``array`` and running ``body``."""

	variable: str # TODO allow for pattern matching
	array: Expression
	body: Statement

@dataclass
class Import(Statement):
	"""Import statement naming another source ``file``."""

	file: str

@dataclass
class TypeDefinition(Statement):
	"""Binds the type alias ``name`` to the type ``expression``."""

	name: str
	expression: TypeExpression

@dataclass
class DeferStatement(Statement):
	"""Wraps a ``statement`` whose execution is deferred.

	When it actually runs is decided by the interpreter/compiler, not here.
	The body is indented with a tab to match the rest of the module (it was
	previously the only block using spaces).
	"""

	statement: Statement