from typing import Callable, Dict, List, Optional, Tuple from ppp_lexer import Lexer from ppp_tokens import IdentifierToken, Keyword, KeywordToken, NumberToken, StringToken, Symbol, SymbolToken from ppp_ast import * def parse_identifier(lexer: Lexer) -> str: identifier = lexer.assert_tokenkind(IdentifierToken) assert isinstance(identifier.contents, IdentifierToken) return identifier.contents.identifier def parse_number(lexer: Lexer) -> int: number = lexer.assert_tokenkind(NumberToken) assert isinstance(number.contents, NumberToken) return number.contents.number def parse_string(lexer: Lexer) -> str: string = lexer.assert_tokenkind(StringToken) assert isinstance(string.contents, StringToken) return string.contents.string def parse_type_primary(lexer: Lexer) -> TypeExpression: base_type: TypeExpression if lexer.take_token(SymbolToken(Symbol.Open)): if lexer.take_token(SymbolToken(Symbol.Close)): return TupleTypeExpr([]) types: List[TypeExpression] = [parse_type(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): types.append(parse_type(lexer)) lexer.assert_token(SymbolToken(Symbol.Close)) base_type = TupleTypeExpr(types) elif lexer.take_token(SymbolToken(Symbol.OpenSquare)): type = parse_type(lexer) lexer.assert_token(SymbolToken(Symbol.CloseSquare)) base_type = ListTypeExpr(type) else: name = parse_identifier(lexer) base_type = TypeName(name) while (opening_token := lexer.take_tokens(SymbolToken(Symbol.OpenSquare), SymbolToken(Symbol.Left))): assert isinstance(opening_token.contents, SymbolToken) opening = opening_token.contents.symbol if opening == Symbol.OpenSquare and lexer.check_tokenkind(NumberToken): number = parse_number(lexer) lexer.assert_token(SymbolToken(Symbol.CloseSquare)) base_type = ArrayTypeExpr(base_type, number) continue opening2closing_map: Dict[Symbol, Symbol] = { Symbol.OpenSquare: Symbol.CloseSquare, Symbol.Left: Symbol.Right } assert opening in opening2closing_map, "Unreachable" closing = opening2closing_map[opening] if opening == Symbol.OpenSquare and lexer.take_token(SymbolToken(closing)): base_type = ListTypeExpr(base_type) continue generics: List[TypeExpression] = [parse_type(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): generics.append(parse_type(lexer)) lexer.assert_token(SymbolToken(closing)) assert not isinstance(base_type, TypeSpecification) base_type = TypeSpecification(base_type, generics) return base_type def parse_type(lexer: Lexer) -> TypeExpression: base_type = parse_type_primary(lexer) if not lexer.take_token(SymbolToken(Symbol.Arrow)): return base_type return_type = parse_type(lexer) return FunctionTypeExpr([base_type] if not isinstance(base_type, TupleTypeExpr) else base_type.types, return_type) def parse_type_declaration(lexer: Lexer) -> TypeDeclaration: entry_name = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.Colon)) entry_type = parse_type(lexer) return TypeDeclaration(entry_name, entry_type) def parse_enum_entry(lexer: Lexer) -> EnumEntry: entry_name = parse_identifier(lexer) if not lexer.take_token(SymbolToken(Symbol.Open)): return EnumEntry(entry_name, []) entry_types: List[TypeExpression] = [parse_type(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): entry_types.append(parse_type(lexer)) lexer.assert_token(SymbolToken(Symbol.Close)) return EnumEntry(entry_name, entry_types) def parse_primary(lexer: Lexer) -> Expression: base_expression: Expression if lexer.take_token(SymbolToken(Symbol.Open)): if lexer.take_token(SymbolToken(Symbol.Close)): base_expression = TupleExpr([]) else: elements: List[Expression] = [parse_expression(lexer)] singleton: bool = False while lexer.take_token(SymbolToken(Symbol.Comma)): if lexer.check_token(SymbolToken(Symbol.Close)) and len(elements) == 1: singleton = True break elements.append(parse_expression(lexer)) lexer.assert_token(SymbolToken(Symbol.Close)) if singleton or len(elements) > 1: base_expression = TupleExpr(elements) else: base_expression = elements[0] elif lexer.take_token(SymbolToken(Symbol.OpenSquare)): lexer.assert_token(SymbolToken(Symbol.Colon)) element_type = parse_type(lexer) if lexer.take_token(SymbolToken(Symbol.CloseSquare)): base_expression = Array(element_type, []) else: lexer.assert_token(SymbolToken(Symbol.Comma)) expressions: List[Expression] = [parse_expression(lexer)] if lexer.take_token(KeywordToken(Keyword.For)): variable = parse_identifier(lexer) # TODO: Pattern matching lexer.assert_token(KeywordToken(Keyword.In)) expression = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.CloseSquare)) base_expression = LoopComprehension(element_type, expressions[0], variable, expression) else: while lexer.take_token(SymbolToken(Symbol.Comma)): expressions.append(parse_expression(lexer)) lexer.assert_token(SymbolToken(Symbol.CloseSquare)) base_expression = Array(element_type, expressions) elif lexer.check_tokenkind(StringToken): base_expression = String(parse_string(lexer)) elif lexer.check_tokenkind(NumberToken): base_expression = Number(parse_number(lexer)) else: base_expression = Variable(parse_identifier(lexer)) while (token := lexer.take_tokens(SymbolToken(Symbol.Open), SymbolToken(Symbol.OpenSquare), SymbolToken(Symbol.Dot))): match token.contents: case SymbolToken(symbol): match symbol: case Symbol.Dot: next_token = lexer.next_token() match next_token.contents: case IdentifierToken(identifier=field): base_expression = FieldAccess(base_expression, field) case SymbolToken(symbol=symbol): match symbol: case Symbol.OpenCurly: if lexer.take_token(SymbolToken(Symbol.CloseCurly)): base_expression = StructInstantiation(base_expression, []) else: def parse_argument() -> Tuple[str, Expression]: parameter = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.Equal)) return (parameter, parse_expression(lexer)) struct_arguments: List[Tuple[str, Expression]] = [parse_argument()] while lexer.take_token(SymbolToken(Symbol.Comma)): struct_arguments.append(parse_argument()) lexer.assert_token(SymbolToken(Symbol.CloseCurly)) base_expression = StructInstantiation(base_expression, struct_arguments) case _: raise SyntaxError(f"{next_token.loc}: Unexpected symbol: {repr(str(symbol))}") case _: raise SyntaxError(f"{next_token.loc}: Unexpected: {next_token.contents}") case Symbol.Open: if lexer.take_token(SymbolToken(Symbol.Close)): base_expression = FunctionCall(base_expression, []) else: arguments: List[Expression] = [parse_expression(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): arguments.append(parse_expression(lexer)) lexer.assert_token(SymbolToken(Symbol.Close)) base_expression = FunctionCall(base_expression, arguments) case Symbol.OpenSquare: index = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.CloseSquare)) base_expression = ArrayAccess(base_expression, index) case _: assert False, ("Unimplemented", symbol) case _: assert False, ("Unimplemented", token) return base_expression def parse_unary(lexer: Lexer) -> Expression: if lexer.take_token(SymbolToken(Symbol.Tilde)): return Bnot(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Exclamation)): return Not(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Plus)): return UnaryPlus(parse_unary(lexer)) if lexer.take_token(SymbolToken(Symbol.Dash)): return UnaryMinus(parse_unary(lexer)) return parse_primary(lexer) Precedence = Dict[Symbol, Callable[[Expression, Expression], Expression]] precedences: List[Precedence] = [ {Symbol.Dpipe: Or}, {Symbol.Dampersand: And}, {Symbol.Pipe: Bor}, {Symbol.Carot: Bxor}, {Symbol.Ampersand: Band}, {Symbol.Dequal: Equal, Symbol.NotEqual: NotEqual}, {Symbol.Left: LessThan, Symbol.Right: GreaterThan, Symbol.LesserEqual: LessThanOrEqual, Symbol.GreaterEqual: GreaterThanOrEqual}, {Symbol.Dleft: ShiftLeft, Symbol.Dright: ShiftRight}, {Symbol.Plus: Addition, Symbol.Dash: Subtract}, {Symbol.Asterisk: Multiplication, Symbol.Slash: Division, Symbol.Percent: Modulo} ] def parse_expression_at_level(lexer: Lexer, level: int=0) -> Expression: if level >= len(precedences): return parse_unary(lexer) left = parse_expression_at_level(lexer, level+1) tokens = [SymbolToken(symbol) for symbol in precedences[level]] while (token := lexer.take_tokens(*tokens)): assert isinstance(token.contents, SymbolToken) left = precedences[level][token.contents.symbol](left, parse_expression_at_level(lexer, level+1)) return left def parse_ternary(lexer: Lexer) -> Expression: expression = parse_expression_at_level(lexer) if not lexer.take_token(SymbolToken(Symbol.QuestionMark)): return expression if_true = parse_expression_at_level(lexer) lexer.assert_token(SymbolToken(Symbol.Colon)) if_false = parse_ternary(lexer) return Ternary(expression, if_true, if_false) def parse_expression(lexer: Lexer) -> Expression: if lexer.take_token(KeywordToken(Keyword.Lambda)): parameters: List[TypeDeclaration] if lexer.take_token(SymbolToken(Symbol.EqualArrow)): parameters = [] else: parameters = [parse_type_declaration(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): parameters.append(parse_type_declaration(lexer)) lexer.assert_token(SymbolToken(Symbol.EqualArrow)) return Lambda(parameters, parse_expression(lexer)) return parse_ternary(lexer) def is_valid_target(expression: Expression) -> bool: match expression: case FieldAccess(subexpression, _): return is_valid_target(subexpression) case Variable(_): return True case ArrayAccess(array, _): return is_valid_target(array) case _: assert False, ("Unimplemeneted", expression) assert False, "Unreachable" def parse_statement(lexer: Lexer) -> Statement: if lexer.take_token(KeywordToken(Keyword.Enum)): enum_name = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.OpenCurly)) if lexer.take_token(SymbolToken(Symbol.CloseCurly)): return EnumDefinition(enum_name, []) enum_entries: List[EnumEntry] = [parse_enum_entry(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): enum_entries.append(parse_enum_entry(lexer)) lexer.assert_token(SymbolToken(Symbol.CloseCurly)) return EnumDefinition(enum_name, enum_entries) elif lexer.take_token(KeywordToken(Keyword.Struct)): struct_name = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.OpenCurly)) if lexer.take_token(SymbolToken(Symbol.CloseCurly)): return StructDefinition(struct_name, []) struct_entries: List[TypeDeclaration] = [parse_type_declaration(lexer)] while lexer.take_token(SymbolToken(Symbol.Comma)): struct_entries.append(parse_type_declaration(lexer)) lexer.assert_token(SymbolToken(Symbol.CloseCurly)) return StructDefinition(struct_name, struct_entries) elif lexer.take_token(KeywordToken(Keyword.Func)): function_name = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.Open)) function_arguments: List[TypeDeclaration] = [] if not lexer.take_token(SymbolToken(Symbol.Close)): function_arguments.append(parse_type_declaration(lexer)) while lexer.take_token(SymbolToken(Symbol.Comma)): function_arguments.append(parse_type_declaration(lexer)) lexer.assert_token(SymbolToken(Symbol.Close)) function_return_type: Optional[TypeExpression] = None if lexer.take_token(SymbolToken(Symbol.Arrow)): function_return_type = parse_type(lexer) function_body = parse_statement(lexer) return FunctionDefinition(function_name, function_arguments, function_return_type, function_body) elif lexer.take_token(KeywordToken(Keyword.If)): return IfStatement( parse_expression(lexer), parse_statement(lexer), parse_statement(lexer) if lexer.take_token(KeywordToken(Keyword.Else)) else None ) elif lexer.take_token(KeywordToken(Keyword.Else)): assert False, "Unmatched else" elif lexer.take_token(KeywordToken(Keyword.While)): return WhileStatement( parse_expression(lexer), parse_statement(lexer) ) elif lexer.take_token(KeywordToken(Keyword.Break)): lexer.assert_token(SymbolToken(Symbol.Semicolon)) return BreakStatement() elif lexer.take_token(KeywordToken(Keyword.Continue)): lexer.assert_token(SymbolToken(Symbol.Semicolon)) return ContinueStatement() elif lexer.take_token(KeywordToken(Keyword.Return)): expression = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.Semicolon)) return ReturnStatement(expression) elif lexer.take_token(KeywordToken(Keyword.Do)): body = parse_statement(lexer) condition: Optional[Expression] = None if lexer.take_token(KeywordToken(Keyword.While)): condition = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.Semicolon)) return DoWhileStatement(body, condition) elif lexer.take_token(KeywordToken(Keyword.Match)): value = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.OpenCurly)) cases: List[Tuple[Expression, Statement]] = [] while lexer.take_token(KeywordToken(Keyword.Case)): cases.append((parse_expression(lexer), parse_statement(lexer))) lexer.assert_token(SymbolToken(Symbol.CloseCurly)) return MatchStatement(value, cases) elif lexer.take_token(KeywordToken(Keyword.Assert)): condition = parse_expression(lexer) message = parse_expression(lexer) if lexer.take_token(SymbolToken(Symbol.Comma)) else None lexer.assert_token(SymbolToken(Symbol.Semicolon)) return AssertStatement(condition, message) elif lexer.take_token(KeywordToken(Keyword.For)): variable = parse_identifier(lexer) # TODO: Allow for pattern matching here lexer.assert_token(KeywordToken(Keyword.In)) expression = parse_expression(lexer) body = parse_statement(lexer) return ForLoop(variable, expression, body) elif lexer.take_token(KeywordToken(Keyword.Import)): file = parse_string(lexer) lexer.assert_token(SymbolToken(Symbol.Semicolon)) return Import(file) elif lexer.take_token(KeywordToken(Keyword.Type)): name = parse_identifier(lexer) lexer.assert_token(SymbolToken(Symbol.Equal)) type_expression = parse_type(lexer) lexer.assert_token(SymbolToken(Symbol.Semicolon)) return TypeDefinition(name, type_expression) elif lexer.take_token(KeywordToken(Keyword.Defer)): statement = parse_statement(lexer) return DeferStatement(statement) elif lexer.check_tokenkind(KeywordToken) and not lexer.check_token(KeywordToken(Keyword.Lambda)): # TODO: Maybe use '\' for lambda instead of a keyword token = lexer.next_token() assert isinstance(token.contents, KeywordToken) raise SyntaxError(f"{token.loc}: Unexpected keyword: '{token.contents.keyword}'") elif lexer.take_token(SymbolToken(Symbol.OpenCurly)): statements: List[Statement] = [] while not lexer.take_token(SymbolToken(Symbol.CloseCurly)): statements.append(parse_statement(lexer)) return Statements(statements) else: expression = parse_expression(lexer) type: Optional[TypeExpression] = None if lexer.take_token(SymbolToken(Symbol.Colon)): assert isinstance(expression, Variable), "Cannot declare types for anything besides a variable" type = parse_type(lexer) if lexer.take_token(SymbolToken(Symbol.Equal)): assert is_valid_target(expression), ("Invalid target!", expression) right_expression = parse_expression(lexer) lexer.assert_token(SymbolToken(Symbol.Semicolon)) return Assignment(expression, right_expression, type) lexer.assert_token(SymbolToken(Symbol.Semicolon)) if type and isinstance(expression, Variable): return TypeDeclarationStatement(TypeDeclaration(expression.name, type)) return ExpressionStatement(expression)