X-Git-Url: https://git.phdru.name/?a=blobdiff_plain;ds=sidebyside;f=mysql2sql%2Fprocess_tokens.py;h=1e74ac9b69a5a416bb6cf8b0418219095b73b357;hb=4c93c3d89685aba33fc45082022373eb93b6583e;hp=94879cc555ea2bbbb399ecb655657dfead8ceab0;hpb=031cc0d6a41717d4c5d7c4659290e05810202eb9;p=sqlconvert.git diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index 94879cc..1e74ac9 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,6 +1,8 @@ -from sqlparse.sql import Statement -from sqlparse.tokens import Name, Error, Punctuation +from sqlparse import parse +from sqlparse.compat import PY3 +from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ + Whitespace def requote_names(token_list): @@ -24,23 +26,49 @@ def find_error(token_list): return False +if PY3: + xrange = range + + class StatementGrouper(object): - def __init__(self): - self.tokens = [] + """Collect lines and reparse until the last statement is complete""" + + def __init__(self, encoding=None): + self.lines = [] self.statements = [] + self.encoding = encoding + + def process_line(self, line): + self.lines.append(line) + self.process_lines() + + def process_lines(self): + statements = parse(''.join(self.lines), encoding=self.encoding) + last_stmt = statements[-1] + for i in xrange(len(last_stmt.tokens) - 1, 0, -1): + token = last_stmt.tokens[i] + if token.ttype in (Comment.Single, Comment.Multiline, + Newline, Whitespace): + continue + if token.ttype is Punctuation and token.value == ';': + break # The last statement is complete + # The last statement is still incomplete - wait for the next line + return + self.lines = [] + self.statements = statements def get_statements(self): - for statement in self.statements: - yield statement + for stmt in self.statements: + yield stmt self.statements = [] - def process(self, tokens): - for token in tokens: - self.tokens.append(token) - if (token.ttype == Punctuation) and (token.value == ';'): - self.statements.append(Statement(self.tokens)) - self.tokens = [] - def close(self): - if self.tokens: - raise ValueError("Incomplete SQL statement") + if not self.lines: + return + tokens = parse(''.join(self.lines), encoding=self.encoding) + for token in tokens: + if (token.ttype not in (Comment.Single, Comment.Multiline, + Newline, Whitespace)): + raise ValueError("Incomplete SQL statement: %s" % + tokens) + return tokens