from sqlparse import parse
-from sqlparse.sql import Statement
from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \
Whitespace
class StatementGrouper(object):
+ """Collect lines and reparse until the last statement is complete"""
+
def __init__(self):
- self.statements = []
- self.tokens = []
self.lines = []
+ self.statements = []
def process_line(self, line):
- lines = self.lines
- lines.append(line)
- tokens = parse('\n'.join(lines))[0]
- self.process_tokens(tokens)
- self.lines = []
+ self.lines.append(line)
+ self.process_lines()
- def process_tokens(self, tokens):
- for token in tokens:
- self.tokens.append(token)
- if (token.ttype == Punctuation) and (token.value == ';'):
- self.statements.append(Statement(self.tokens))
- self.tokens = []
+ def process_lines(self):
+ statements = parse('\n'.join(self.lines))
+ last_stmt = statements[-1]
+ for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
+ token = last_stmt.tokens[i]
+ if token.ttype in (Comment.Single, Comment.Multiline,
+ Newline, Whitespace):
+ continue
+ if token.ttype is Punctuation and token.value == ';':
+ break # The last statement is complete
+ # The last statement is still incomplete - wait for the next line
+ return
+ self.lines = []
+ self.statements = statements
def get_statements(self):
- for statement in self.statements:
- yield statement
+ for stmt in self.statements:
+ yield stmt
self.statements = []
def close(self):
- for token in self.tokens:
+ if not self.lines:
+ return
+ tokens = parse('\n'.join(self.lines))
+ for token in tokens:
if (token.ttype not in (Comment.Single, Comment.Multiline,
Newline, Whitespace)):
- raise ValueError("Incomplete SQL statement: %s" % self.tokens)
- return self.tokens
+ raise ValueError("Incomplete SQL statement: %s" %
+ tokens)
+ return tokens
query = tlist2str(statement)
self.assertEqual(query, 'SELECT * FROM "T";')
self.assertEqual(len(grouper.statements), 0)
- self.assertEqual(grouper.close(), [])
+ self.assertEqual(grouper.close(), None)
if __name__ == "__main__":
main()