From: Oleg Broytman Date: Wed, 24 Aug 2016 14:44:41 +0000 (+0300) Subject: Process input stream line by line X-Git-Tag: 0.0.1~21 X-Git-Url: https://git.phdru.name/?a=commitdiff_plain;h=96d6304e4c08ef94bb7dfac3e3069fcb8c982747;p=sqlconvert.git Process input stream line by line --- diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index 7f30a55..37752e7 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,4 +1,5 @@ +from sqlparse import parse from sqlparse.sql import Statement from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ Whitespace @@ -27,21 +28,29 @@ def find_error(token_list): class StatementGrouper(object): def __init__(self): - self.tokens = [] self.statements = [] + self.tokens = [] + self.lines = [] - def get_statements(self): - for statement in self.statements: - yield statement - self.statements = [] + def process_line(self, line): + lines = self.lines + lines.append(line) + tokens = parse('\n'.join(lines))[0] + self.process_tokens(tokens) + self.lines = [] - def process(self, tokens): + def process_tokens(self, tokens): for token in tokens: self.tokens.append(token) if (token.ttype == Punctuation) and (token.value == ';'): self.statements.append(Statement(self.tokens)) self.tokens = [] + def get_statements(self): + for statement in self.statements: + yield statement + self.statements = [] + def close(self): for token in self.tokens: if (token.ttype not in (Comment.Single, Comment.Multiline, diff --git a/scripts/group-file.py b/scripts/group-file.py index 53a11f1..92f8fb9 100755 --- a/scripts/group-file.py +++ b/scripts/group-file.py @@ -2,7 +2,6 @@ from __future__ import print_function import sys -from sqlparse import parse from mysql2sql.print_tokens import print_tokens from mysql2sql.process_tokens import requote_names, find_error, \ StatementGrouper @@ -12,7 +11,7 @@ def main(filename): grouper = StatementGrouper() with open(filename) as infile: for line in infile: - grouper.process(parse(line)[0]) + grouper.process_line(line) if grouper.statements: for statement in grouper.get_statements(): print("----------") diff --git a/scripts/group-sql.py b/scripts/group-sql.py index 0eab6b4..850dcb3 100755 --- a/scripts/group-sql.py +++ b/scripts/group-sql.py @@ -2,7 +2,6 @@ from __future__ import print_function import sys -from sqlparse import parse from mysql2sql.print_tokens import print_tokens from mysql2sql.process_tokens import requote_names, find_error, \ StatementGrouper @@ -11,7 +10,7 @@ from mysql2sql.process_tokens import requote_names, find_error, \ def main(*queries): grouper = StatementGrouper() for query in queries: - grouper.process(parse(query)[0]) + grouper.process_line(query) if grouper.statements: for statement in grouper.get_statements(): print("----------") diff --git a/tests/test_stgrouper.py b/tests/test_stgrouper.py index cb4f88f..bd60168 100755 --- a/tests/test_stgrouper.py +++ b/tests/test_stgrouper.py @@ -2,7 +2,6 @@ import unittest -from sqlparse import parse from mysql2sql.print_tokens import tlist2str from mysql2sql.process_tokens import requote_names, StatementGrouper @@ -12,16 +11,14 @@ from tests import main class TestStGrouper(unittest.TestCase): def test_incomplete(self): grouper = StatementGrouper() - parsed = parse("select * from `T`")[0] - grouper.process(parsed) + grouper.process_line("select * from `T`") self.assertFalse(grouper.statements) self.assertEqual(len(grouper.statements), 0) self.assertRaises(ValueError, grouper.close) def test_statements(self): grouper = StatementGrouper() - parsed = parse("select * from `T`;")[0] - grouper.process(parsed) + grouper.process_line("select * from `T`;") self.assertTrue(grouper.statements) self.assertEqual(len(grouper.statements), 1) for statement in grouper.get_statements():