From 0f209248832190291a510febd790e71a815e496a Mon Sep 17 00:00:00 2001 From: Oleg Broytman Date: Sat, 27 Aug 2016 00:48:02 +0300 Subject: [PATCH] Collect lines and reparse until the last statement is complete --- mysql2sql/process_tokens.py | 47 ++++++++++++++++++++++--------------- scripts/group-file.py | 7 +++--- scripts/group-sql.py | 7 +++--- tests/test_stgrouper.py | 2 +- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index 37752e7..ddab2bf 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,6 +1,5 @@ from sqlparse import parse -from sqlparse.sql import Statement from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ Whitespace @@ -27,33 +26,43 @@ def find_error(token_list): class StatementGrouper(object): + """Collect lines and reparse until the last statement is complete""" + def __init__(self): - self.statements = [] - self.tokens = [] self.lines = [] + self.statements = [] def process_line(self, line): - lines = self.lines - lines.append(line) - tokens = parse('\n'.join(lines))[0] - self.process_tokens(tokens) - self.lines = [] + self.lines.append(line) + self.process_lines() - def process_tokens(self, tokens): - for token in tokens: - self.tokens.append(token) - if (token.ttype == Punctuation) and (token.value == ';'): - self.statements.append(Statement(self.tokens)) - self.tokens = [] + def process_lines(self): + statements = parse('\n'.join(self.lines)) + last_stmt = statements[-1] + for i in xrange(len(last_stmt.tokens) - 1, 0, -1): + token = last_stmt.tokens[i] + if token.ttype in (Comment.Single, Comment.Multiline, + Newline, Whitespace): + continue + if token.ttype is Punctuation and token.value == ';': + break # The last statement is complete + # The last statement is still incomplete - wait for the next line + return + self.lines = [] + self.statements = statements def get_statements(self): - for statement in self.statements: - yield statement + for stmt in self.statements: + yield stmt self.statements = [] def close(self): - for token in self.tokens: + if not self.lines: + return + tokens = parse('\n'.join(self.lines)) + for token in tokens: if (token.ttype not in (Comment.Single, Comment.Multiline, Newline, Whitespace)): - raise ValueError("Incomplete SQL statement: %s" % self.tokens) - return self.tokens + raise ValueError("Incomplete SQL statement: %s" % + tokens) + return tokens diff --git a/scripts/group-file.py b/scripts/group-file.py index 92f8fb9..5ab1f44 100755 --- a/scripts/group-file.py +++ b/scripts/group-file.py @@ -23,9 +23,10 @@ def main(filename): statement._pprint_tree() print("----------") tokens = grouper.close() - for token in tokens: - print_tokens(token) - print(repr(token)) + if tokens: + for token in tokens: + print_tokens(token) + print(repr(token)) if __name__ == '__main__': diff --git a/scripts/group-sql.py b/scripts/group-sql.py index 850dcb3..953e9e7 100755 --- a/scripts/group-sql.py +++ b/scripts/group-sql.py @@ -22,9 +22,10 @@ def main(*queries): statement._pprint_tree() print("----------") tokens = grouper.close() - for token in tokens: - print_tokens(token) - print(repr(token)) + if tokens: + for token in tokens: + print_tokens(token) + print(repr(token)) def test(): diff --git a/tests/test_stgrouper.py b/tests/test_stgrouper.py index bd60168..d77773e 100755 --- a/tests/test_stgrouper.py +++ b/tests/test_stgrouper.py @@ -26,7 +26,7 @@ class TestStGrouper(unittest.TestCase): query = tlist2str(statement) self.assertEqual(query, 'SELECT * FROM "T";') self.assertEqual(len(grouper.statements), 0) - self.assertEqual(grouper.close(), []) + self.assertEqual(grouper.close(), None) if __name__ == "__main__": main() -- 2.39.5