X-Git-Url: https://git.phdru.name/?p=sqlconvert.git;a=blobdiff_plain;f=mysql2sql%2Fprocess_tokens.py;h=1e74ac9b69a5a416bb6cf8b0418219095b73b357;hp=c395b37df0cd94bb1cb9b09df2f4b0fd21cd3583;hb=4c93c3d89685aba33fc45082022373eb93b6583e;hpb=3289238688d9c3dfefccf143d21f6c406faad9e4 diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index c395b37..1e74ac9 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,18 +1,74 @@ -from sqlparse.sql import TokenList -from sqlparse.tokens import Name +from sqlparse import parse +from sqlparse.compat import PY3 +from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ + Whitespace def requote_names(token_list): """Remove backticks, quote non-lowercase identifiers""" - for token in token_list: - if isinstance(token, TokenList): - requote_names(token) - else: - if token.ttype is Name: - value = token.value - if (value[0] == "`") and (value[-1] == "`"): - value = value[1:-1] - token.normalized = token.value = value - if not value.islower(): - token.normalized = token.value = '"%s"' % value + for token in token_list.flatten(): + if token.ttype is Name: + value = token.value + if (value[0] == "`") and (value[-1] == "`"): + value = value[1:-1] + if value.islower(): + token.normalized = token.value = value + else: + token.normalized = token.value = '"%s"' % value + + +def find_error(token_list): + """Find an error""" + for token in token_list.flatten(): + if token.ttype is Error: + return True + return False + + +if PY3: + xrange = range + + +class StatementGrouper(object): + """Collect lines and reparse until the last statement is complete""" + + def __init__(self, encoding=None): + self.lines = [] + self.statements = [] + self.encoding = encoding + + def process_line(self, line): + self.lines.append(line) + self.process_lines() + + def process_lines(self): + statements = parse(''.join(self.lines), encoding=self.encoding) + last_stmt = statements[-1] + for i in xrange(len(last_stmt.tokens) - 1, 0, -1): + token = last_stmt.tokens[i] + if token.ttype in (Comment.Single, Comment.Multiline, + Newline, Whitespace): + continue + if token.ttype is Punctuation and token.value == ';': + break # The last statement is complete + # The last statement is still incomplete - wait for the next line + return + self.lines = [] + self.statements = statements + + def get_statements(self): + for stmt in self.statements: + yield stmt + self.statements = [] + + def close(self): + if not self.lines: + return + tokens = parse(''.join(self.lines), encoding=self.encoding) + for token in tokens: + if (token.ttype not in (Comment.Single, Comment.Multiline, + Newline, Whitespace)): + raise ValueError("Incomplete SQL statement: %s" % + tokens) + return tokens