X-Git-Url: https://git.phdru.name/?a=blobdiff_plain;f=mysql2sql%2Fprocess_tokens.py;h=7f30a5589b7fad13265b4a9e147fe6308afba545;hb=196c2e0ff729c9cf772bb0e9b521ce4f4fdf84cb;hp=e9a62f825a6d5625dcd6dfc2e03b2f150208a12b;hpb=b1e2b16d0f2bbb3ba28e9d37a087e000c4dc0db5;p=sqlconvert.git diff --git a/mysql2sql/process_tokens.py b/mysql2sql/process_tokens.py index e9a62f8..7f30a55 100644 --- a/mysql2sql/process_tokens.py +++ b/mysql2sql/process_tokens.py @@ -1,19 +1,50 @@ -from sqlparse.sql import TokenList -from sqlparse.tokens import Name +from sqlparse.sql import Statement +from sqlparse.tokens import Name, Error, Punctuation, Comment, Newline, \ + Whitespace def requote_names(token_list): """Remove backticks, quote non-lowercase identifiers""" - for token in token_list: - if isinstance(token, TokenList): - requote_names(token) - else: - if token.ttype is Name: - value = token.value - if (value[0] == "`") and (value[-1] == "`"): - value = value[1:-1] - if value.islower(): - token.normalized = token.value = value - else: - token.normalized = token.value = '"%s"' % value + for token in token_list.flatten(): + if token.ttype is Name: + value = token.value + if (value[0] == "`") and (value[-1] == "`"): + value = value[1:-1] + if value.islower(): + token.normalized = token.value = value + else: + token.normalized = token.value = '"%s"' % value + + +def find_error(token_list): + """Find an error""" + for token in token_list.flatten(): + if token.ttype is Error: + return True + return False + + +class StatementGrouper(object): + def __init__(self): + self.tokens = [] + self.statements = [] + + def get_statements(self): + for statement in self.statements: + yield statement + self.statements = [] + + def process(self, tokens): + for token in tokens: + self.tokens.append(token) + if (token.ttype == Punctuation) and (token.value == ';'): + self.statements.append(Statement(self.tokens)) + self.tokens = [] + + def close(self): + for token in self.tokens: + if (token.ttype not in (Comment.Single, Comment.Multiline, + Newline, Whitespace)): + raise ValueError("Incomplete SQL statement: %s" % self.tokens) + return self.tokens