# python-plus-plus/ppp_parser.py

from typing import Callable, Dict, List, Optional, Tuple
from ppp_lexer import Lexer
from ppp_tokens import IdentifierToken, Keyword, KeywordToken, NumberToken, StringToken, Symbol, SymbolToken
from ppp_ast import *

# Token-level helpers: each asserts that the next token has the expected
# kind and unwraps its payload.

def parse_identifier(lexer: Lexer) -> str:
    identifier = lexer.assert_tokenkind(IdentifierToken)
    assert isinstance(identifier.contents, IdentifierToken)
    return identifier.contents.identifier

def parse_number(lexer: Lexer) -> int:
    number = lexer.assert_tokenkind(NumberToken)
    assert isinstance(number.contents, NumberToken)
    return number.contents.number

def parse_string(lexer: Lexer) -> str:
    string = lexer.assert_tokenkind(StringToken)
    assert isinstance(string.contents, StringToken)
    return string.contents.string

def parse_type_primary(lexer: Lexer) -> TypeExpression:
    # Prefix forms: '(' tuple ')', '[' element ']' list, or a bare type name.
    base_type: TypeExpression
    if lexer.take_token(SymbolToken(Symbol.Open)):
        if lexer.take_token(SymbolToken(Symbol.Close)): return TupleTypeExpr([])
        types: List[TypeExpression] = [parse_type(lexer)]
        while lexer.take_token(SymbolToken(Symbol.Comma)):
            types.append(parse_type(lexer))
        lexer.assert_token(SymbolToken(Symbol.Close))
        base_type = TupleTypeExpr(types)
    elif lexer.take_token(SymbolToken(Symbol.OpenSquare)):
        type = parse_type(lexer)
        lexer.assert_token(SymbolToken(Symbol.CloseSquare))
        base_type = ListTypeExpr(type)
    else:
        name = parse_identifier(lexer)
        base_type = TypeName(name)
    # Suffix forms: 'T[N]' array, 'T[]' list, and 'T[...]'/'T<...>' generics.
    while (opening_token := lexer.take_tokens(SymbolToken(Symbol.OpenSquare), SymbolToken(Symbol.Left))):
        assert isinstance(opening_token.contents, SymbolToken)
        opening = opening_token.contents.symbol
        if opening == Symbol.OpenSquare and lexer.check_tokenkind(NumberToken):
            number = parse_number(lexer)
            lexer.assert_token(SymbolToken(Symbol.CloseSquare))
            base_type = ArrayTypeExpr(base_type, number)
            continue
        opening2closing_map: Dict[Symbol, Symbol] = {
            Symbol.OpenSquare: Symbol.CloseSquare,
            Symbol.Left: Symbol.Right
        }
        assert opening in opening2closing_map, "Unreachable"
        closing = opening2closing_map[opening]
        if opening == Symbol.OpenSquare and lexer.take_token(SymbolToken(closing)):
            base_type = ListTypeExpr(base_type)
            continue
        generics: List[TypeExpression] = [parse_type(lexer)]
        while lexer.take_token(SymbolToken(Symbol.Comma)): generics.append(parse_type(lexer))
        lexer.assert_token(SymbolToken(closing))
        assert not isinstance(base_type, TypeSpecification)
        base_type = TypeSpecification(base_type, generics)
    return base_type

def parse_type(lexer: Lexer) -> TypeExpression:
    # '->' is right-associative; a tuple on the left is flattened into the
    # function's parameter list.
    base_type = parse_type_primary(lexer)
    if not lexer.take_token(SymbolToken(Symbol.Arrow)): return base_type
    return_type = parse_type(lexer)
    return FunctionTypeExpr([base_type] if not isinstance(base_type, TupleTypeExpr) else base_type.types, return_type)
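
# Illustrative examples of the type grammar above (surface syntax inferred
# from this parser; exact AST shapes assume the ppp_ast constructors):
#   (int, str)     => TupleTypeExpr([TypeName("int"), TypeName("str")])
#   [int]          => ListTypeExpr(TypeName("int"))
#   int[4]         => ArrayTypeExpr(TypeName("int"), 4)
#   int[]          => ListTypeExpr(TypeName("int"))
#   map<str, int>  => TypeSpecification(TypeName("map"), [...])
#   int -> int     => FunctionTypeExpr([TypeName("int")], TypeName("int"))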

def parse_type_declaration(lexer: Lexer) -> TypeDeclaration:
    entry_name = parse_identifier(lexer)
    lexer.assert_token(SymbolToken(Symbol.Colon))
    entry_type = parse_type(lexer)
    return TypeDeclaration(entry_name, entry_type)

def parse_enum_entry(lexer: Lexer) -> EnumEntry:
    entry_name = parse_identifier(lexer)
    if not lexer.take_token(SymbolToken(Symbol.Open)):
        return EnumEntry(entry_name, [])
    entry_types: List[TypeExpression] = [parse_type(lexer)]
    while lexer.take_token(SymbolToken(Symbol.Comma)):
        entry_types.append(parse_type(lexer))
    lexer.assert_token(SymbolToken(Symbol.Close))
    return EnumEntry(entry_name, entry_types)
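
# For example (inferred from parse_enum_entry and the Enum branch of
# parse_statement below), "enum Option { None, Some(int) }" yields
# EnumDefinition("Option", [EnumEntry("None", []),
#                           EnumEntry("Some", [TypeName("int")])]).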

def parse_primary(lexer: Lexer) -> Expression:
    base_expression: Expression
    if lexer.take_token(SymbolToken(Symbol.Open)):
        # Parenthesized expression, '()' empty tuple, or '(a, b)' tuple;
        # a trailing comma as in '(a,)' forces a singleton tuple.
        if lexer.take_token(SymbolToken(Symbol.Close)):
            base_expression = TupleExpr([])
        else:
            elements: List[Expression] = [parse_expression(lexer)]
            singleton: bool = False
            while lexer.take_token(SymbolToken(Symbol.Comma)):
                if lexer.check_token(SymbolToken(Symbol.Close)) and len(elements) == 1:
                    singleton = True
                    break
                elements.append(parse_expression(lexer))
            lexer.assert_token(SymbolToken(Symbol.Close))
            if singleton or len(elements) > 1:
                base_expression = TupleExpr(elements)
            else:
                base_expression = elements[0]
    elif lexer.take_token(SymbolToken(Symbol.OpenSquare)):
        # Array literal '[:T]' / '[:T, a, b]' or comprehension '[:T, e for x in xs]'.
        lexer.assert_token(SymbolToken(Symbol.Colon))
        element_type = parse_type(lexer)
        if lexer.take_token(SymbolToken(Symbol.CloseSquare)):
            base_expression = Array(element_type, [])
        else:
            lexer.assert_token(SymbolToken(Symbol.Comma))
            expressions: List[Expression] = [parse_expression(lexer)]
            if lexer.take_token(KeywordToken(Keyword.For)):
                variable = parse_identifier(lexer)  # TODO: Pattern matching
                lexer.assert_token(KeywordToken(Keyword.In))
                expression = parse_expression(lexer)
                lexer.assert_token(SymbolToken(Symbol.CloseSquare))
                base_expression = LoopComprehension(element_type, expressions[0], variable, expression)
            else:
                while lexer.take_token(SymbolToken(Symbol.Comma)):
                    expressions.append(parse_expression(lexer))
                lexer.assert_token(SymbolToken(Symbol.CloseSquare))
                base_expression = Array(element_type, expressions)
    elif lexer.check_tokenkind(StringToken):
        base_expression = String(parse_string(lexer))
    elif lexer.check_tokenkind(NumberToken):
        base_expression = Number(parse_number(lexer))
    else:
        base_expression = Variable(parse_identifier(lexer))
    # Postfix chain: calls 'f(...)', indexing 'a[i]', field access '.name',
    # and struct instantiation '.{field = value, ...}'.
    while (token := lexer.take_tokens(SymbolToken(Symbol.Open), SymbolToken(Symbol.OpenSquare), SymbolToken(Symbol.Dot))):
        match token.contents:
            case SymbolToken(symbol):
                match symbol:
                    case Symbol.Dot:
                        next_token = lexer.next_token()
                        match next_token.contents:
                            case IdentifierToken(identifier=field):
                                base_expression = FieldAccess(base_expression, field)
                            case SymbolToken(symbol=symbol):
                                match symbol:
                                    case Symbol.OpenCurly:
                                        if lexer.take_token(SymbolToken(Symbol.CloseCurly)):
                                            base_expression = StructInstantiation(base_expression, [])
                                        else:
                                            def parse_argument() -> Tuple[str, Expression]:
                                                parameter = parse_identifier(lexer)
                                                lexer.assert_token(SymbolToken(Symbol.Equal))
                                                return (parameter, parse_expression(lexer))
                                            struct_arguments: List[Tuple[str, Expression]] = [parse_argument()]
                                            while lexer.take_token(SymbolToken(Symbol.Comma)): struct_arguments.append(parse_argument())
                                            lexer.assert_token(SymbolToken(Symbol.CloseCurly))
                                            base_expression = StructInstantiation(base_expression, struct_arguments)
                                    case _:
                                        raise SyntaxError(f"{next_token.loc}: Unexpected symbol: {repr(str(symbol))}")
                            case _:
                                raise SyntaxError(f"{next_token.loc}: Unexpected: {next_token.contents}")
                    case Symbol.Open:
                        if lexer.take_token(SymbolToken(Symbol.Close)):
                            base_expression = FunctionCall(base_expression, [])
                        else:
                            arguments: List[Expression] = [parse_expression(lexer)]
                            while lexer.take_token(SymbolToken(Symbol.Comma)):
                                arguments.append(parse_expression(lexer))
                            lexer.assert_token(SymbolToken(Symbol.Close))
                            base_expression = FunctionCall(base_expression, arguments)
                    case Symbol.OpenSquare:
                        index = parse_expression(lexer)
                        lexer.assert_token(SymbolToken(Symbol.CloseSquare))
                        base_expression = ArrayAccess(base_expression, index)
                    case _: assert False, ("Unimplemented", symbol)
            case _: assert False, ("Unimplemented", token)
    return base_expression
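
# For example, "point.{x = 1, y = 2}" parses (via the Dot/OpenCurly branch
# above) as StructInstantiation(Variable("point"), [("x", Number(1)),
# ("y", Number(2))]), and "foo.bar(x)[0]" nests as
# ArrayAccess(FunctionCall(FieldAccess(Variable("foo"), "bar"),
# [Variable("x")]), Number(0)).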

def parse_unary(lexer: Lexer) -> Expression:
    if lexer.take_token(SymbolToken(Symbol.Tilde)): return Bnot(parse_unary(lexer))
    if lexer.take_token(SymbolToken(Symbol.Exclamation)): return Not(parse_unary(lexer))
    if lexer.take_token(SymbolToken(Symbol.Plus)): return UnaryPlus(parse_unary(lexer))
    if lexer.take_token(SymbolToken(Symbol.Dash)): return UnaryMinus(parse_unary(lexer))
    return parse_primary(lexer)

Precedence = Dict[Symbol, Callable[[Expression, Expression], Expression]]
precedences: List[Precedence] = [
    {Symbol.Dpipe: Or},
    {Symbol.Dampersand: And},
    {Symbol.Pipe: Bor},
    {Symbol.Carot: Bxor},
    {Symbol.Ampersand: Band},
    {Symbol.Dequal: Equal, Symbol.NotEqual: NotEqual},
    {Symbol.Left: LessThan, Symbol.Right: GreaterThan, Symbol.LesserEqual: LessThanOrEqual, Symbol.GreaterEqual: GreaterThanOrEqual},
    {Symbol.Dleft: ShiftLeft, Symbol.Dright: ShiftRight},
    {Symbol.Plus: Addition, Symbol.Dash: Subtract},
    {Symbol.Asterisk: Multiplication, Symbol.Slash: Division, Symbol.Percent: Modulo}
]
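
# Each inner dict is one binary-operator precedence level, ordered loosest
# ('||') to tightest ('*', '/', '%'). For example, "a + b * c" parses as
# Addition(Variable("a"), Multiplication(Variable("b"), Variable("c"))),
# and "a - b - c" folds left: Subtract(Subtract(a, b), c).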

def parse_expression_at_level(lexer: Lexer, level: int=0) -> Expression:
    if level >= len(precedences): return parse_unary(lexer)
    left = parse_expression_at_level(lexer, level+1)
    tokens = [SymbolToken(symbol) for symbol in precedences[level]]
    while (token := lexer.take_tokens(*tokens)):
        assert isinstance(token.contents, SymbolToken)
        left = precedences[level][token.contents.symbol](left, parse_expression_at_level(lexer, level+1))
    return left

def parse_ternary(lexer: Lexer) -> Expression:
    expression = parse_expression_at_level(lexer)
    if not lexer.take_token(SymbolToken(Symbol.QuestionMark)): return expression
    if_true = parse_expression_at_level(lexer)
    lexer.assert_token(SymbolToken(Symbol.Colon))
    if_false = parse_ternary(lexer)
    return Ternary(expression, if_true, if_false)

def parse_expression(lexer: Lexer) -> Expression:
    if lexer.take_token(KeywordToken(Keyword.Lambda)):
        parameters: List[TypeDeclaration]
        if lexer.take_token(SymbolToken(Symbol.EqualArrow)):
            parameters = []
        else:
            parameters = [parse_type_declaration(lexer)]
            while lexer.take_token(SymbolToken(Symbol.Comma)):
                parameters.append(parse_type_declaration(lexer))
            lexer.assert_token(SymbolToken(Symbol.EqualArrow))
        return Lambda(parameters, parse_expression(lexer))
    return parse_ternary(lexer)
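
# For example, "lambda x: int, y: int => x + y" parses as
# Lambda([TypeDeclaration("x", TypeName("int")),
#         TypeDeclaration("y", TypeName("int"))],
#        Addition(Variable("x"), Variable("y"))), and the conditional
# "c ? a : d ? b : e" nests to the right via the recursion in parse_ternary.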

def is_valid_target(expression: Expression) -> bool:
    match expression:
        case FieldAccess(subexpression, _): return is_valid_target(subexpression)
        case Variable(_): return True
        case ArrayAccess(array, _): return is_valid_target(array)
        case _: assert False, ("Unimplemented", expression)
    assert False, "Unreachable"
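
# Assignment targets are variables plus any chain of field accesses and
# indexing over them, e.g. "xs[i].field"; anything else (a call, a literal)
# currently trips the assert above rather than raising a SyntaxError.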

def parse_statement(lexer: Lexer) -> Statement:
    if lexer.take_token(KeywordToken(Keyword.Enum)):
        enum_name = parse_identifier(lexer)
        lexer.assert_token(SymbolToken(Symbol.OpenCurly))
        if lexer.take_token(SymbolToken(Symbol.CloseCurly)): return EnumDefinition(enum_name, [])
        enum_entries: List[EnumEntry] = [parse_enum_entry(lexer)]
        while lexer.take_token(SymbolToken(Symbol.Comma)):
            enum_entries.append(parse_enum_entry(lexer))
        lexer.assert_token(SymbolToken(Symbol.CloseCurly))
        return EnumDefinition(enum_name, enum_entries)
    elif lexer.take_token(KeywordToken(Keyword.Struct)):
        struct_name = parse_identifier(lexer)
        lexer.assert_token(SymbolToken(Symbol.OpenCurly))
        if lexer.take_token(SymbolToken(Symbol.CloseCurly)): return StructDefinition(struct_name, [])
        struct_entries: List[TypeDeclaration] = [parse_type_declaration(lexer)]
        while lexer.take_token(SymbolToken(Symbol.Comma)):
            struct_entries.append(parse_type_declaration(lexer))
        lexer.assert_token(SymbolToken(Symbol.CloseCurly))
        return StructDefinition(struct_name, struct_entries)
    elif lexer.take_token(KeywordToken(Keyword.Func)):
        function_name = parse_identifier(lexer)
        lexer.assert_token(SymbolToken(Symbol.Open))
        function_arguments: List[TypeDeclaration] = []
        if not lexer.take_token(SymbolToken(Symbol.Close)):
            function_arguments.append(parse_type_declaration(lexer))
            while lexer.take_token(SymbolToken(Symbol.Comma)):
                function_arguments.append(parse_type_declaration(lexer))
            lexer.assert_token(SymbolToken(Symbol.Close))
        function_return_type: Optional[TypeExpression] = None
        if lexer.take_token(SymbolToken(Symbol.Arrow)):
            function_return_type = parse_type(lexer)
        function_body = parse_statement(lexer)
        return FunctionDefinition(function_name, function_arguments, function_return_type, function_body)
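    # Illustrative surface syntax for the branch above (inferred from this
    # parser): func add(x: int, y: int) -> int { return x + y; }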
    elif lexer.take_token(KeywordToken(Keyword.If)):
        return IfStatement(
            parse_expression(lexer),
            parse_statement(lexer),
            parse_statement(lexer) if lexer.take_token(KeywordToken(Keyword.Else)) else None
        )
    elif lexer.take_token(KeywordToken(Keyword.Else)):
        assert False, "Unmatched else"
    elif lexer.take_token(KeywordToken(Keyword.While)):
        return WhileStatement(
            parse_expression(lexer),
            parse_statement(lexer)
        )
    elif lexer.take_token(KeywordToken(Keyword.Break)):
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return BreakStatement()
    elif lexer.take_token(KeywordToken(Keyword.Continue)):
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return ContinueStatement()
    elif lexer.take_token(KeywordToken(Keyword.Return)):
        expression = parse_expression(lexer)
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return ReturnStatement(expression)
    elif lexer.take_token(KeywordToken(Keyword.Do)):
        body = parse_statement(lexer)
        condition: Optional[Expression] = None
        if lexer.take_token(KeywordToken(Keyword.While)):
            condition = parse_expression(lexer)
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return DoWhileStatement(body, condition)
    elif lexer.take_token(KeywordToken(Keyword.Match)):
        value = parse_expression(lexer)
        lexer.assert_token(SymbolToken(Symbol.OpenCurly))
        cases: List[Tuple[Expression, Statement]] = []
        while lexer.take_token(KeywordToken(Keyword.Case)): cases.append((parse_expression(lexer), parse_statement(lexer)))
        lexer.assert_token(SymbolToken(Symbol.CloseCurly))
        return MatchStatement(value, cases)
    elif lexer.take_token(KeywordToken(Keyword.Assert)):
        condition = parse_expression(lexer)
        message = parse_expression(lexer) if lexer.take_token(SymbolToken(Symbol.Comma)) else None
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return AssertStatement(condition, message)
    elif lexer.take_token(KeywordToken(Keyword.For)):
        variable = parse_identifier(lexer)  # TODO: Allow for pattern matching here
        lexer.assert_token(KeywordToken(Keyword.In))
        expression = parse_expression(lexer)
        body = parse_statement(lexer)
        return ForLoop(variable, expression, body)
    elif lexer.take_token(KeywordToken(Keyword.Import)):
        file = parse_string(lexer)
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return Import(file)
    elif lexer.take_token(KeywordToken(Keyword.Type)):
        name = parse_identifier(lexer)
        lexer.assert_token(SymbolToken(Symbol.Equal))
        type_expression = parse_type(lexer)
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        return TypeDefinition(name, type_expression)
    elif lexer.take_token(KeywordToken(Keyword.Defer)):
        statement = parse_statement(lexer)
        return DeferStatement(statement)
    elif lexer.check_tokenkind(KeywordToken) and not lexer.check_token(KeywordToken(Keyword.Lambda)):  # TODO: Maybe use '\' for lambda instead of a keyword
        token = lexer.next_token()
        assert isinstance(token.contents, KeywordToken)
        raise SyntaxError(f"{token.loc}: Unexpected keyword: '{token.contents.keyword}'")
    elif lexer.take_token(SymbolToken(Symbol.OpenCurly)):
        statements: List[Statement] = []
        while not lexer.take_token(SymbolToken(Symbol.CloseCurly)):
            statements.append(parse_statement(lexer))
        return Statements(statements)
    else:
        # Expression statement, optionally with a type annotation and/or an
        # assignment: 'expr;', 'name: type;', 'target = value;', or
        # 'name: type = value;'.
        expression = parse_expression(lexer)
        type: Optional[TypeExpression] = None
        if lexer.take_token(SymbolToken(Symbol.Colon)):
            assert isinstance(expression, Variable), "Cannot declare types for anything besides a variable"
            type = parse_type(lexer)
        if lexer.take_token(SymbolToken(Symbol.Equal)):
            assert is_valid_target(expression), ("Invalid target!", expression)
            right_expression = parse_expression(lexer)
            lexer.assert_token(SymbolToken(Symbol.Semicolon))
            return Assignment(expression, right_expression, type)
        lexer.assert_token(SymbolToken(Symbol.Semicolon))
        if type and isinstance(expression, Variable):
            return TypeDeclarationStatement(TypeDeclaration(expression.name, type))
        return ExpressionStatement(expression)
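
# A minimal driver sketch showing how the entry points compose into a
# program loop. Illustrative only: the Lexer constructor and end-of-input
# check are not shown in this file, so both names below are assumptions.
#
#     lexer = Lexer.from_file("program.ppp")  # hypothetical constructor
#     program: List[Statement] = []
#     while not lexer.at_end():                # hypothetical EOF predicate
#         program.append(parse_statement(lexer))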