X-Git-Url: https://git.phdru.name/?p=sqlconvert.git;a=blobdiff_plain;f=sqlconvert%2Fprocess_tokens.py;h=cb14467ccf958b478d888714ebaa6240e10fede2;hp=0bbf94cb703ecb8963d93eeac147ecff0c19f1d4;hb=HEAD;hpb=dbc9220a2b29725f94637607f8d8b00c762deb67 diff --git a/sqlconvert/process_tokens.py b/sqlconvert/process_tokens.py index 0bbf94c..886f36d 100644 --- a/sqlconvert/process_tokens.py +++ b/sqlconvert/process_tokens.py @@ -1,19 +1,43 @@ +from sqlparse.sql import Comment +from sqlobject.converters import sqlrepr from sqlparse import parse -from sqlparse.compat import PY3 -from sqlparse.tokens import Error, Punctuation, Comment, Newline, Whitespace +from sqlparse import tokens as T + +try: + xrange +except NameError: + xrange = range def find_error(token_list): """Find an error""" for token in token_list.flatten(): - if token.ttype is Error: + if token.ttype is T.Error: return True return False -if PY3: - xrange = range +def is_comment_or_space(token): + return isinstance(token, Comment) or \ + token.ttype in (T.Comment, T.Comment.Single, T.Comment.Multiline, + T.Newline, T.Whitespace) + + +def is_newline_statement(statement): + for token in statement.tokens[:]: + if token.ttype is not T.Newline: + return False + return True + + +def escape_strings(token_list, dbname): + """Escape strings""" + for token in token_list.flatten(): + if token.ttype is T.String.Single: + value = token.value[1:-1] # unquote by removing apostrophes + value = sqlrepr(value, dbname) + token.normalized = token.value = value class StatementGrouper(object): @@ -30,13 +54,14 @@ class StatementGrouper(object): def process_lines(self): statements = parse(''.join(self.lines), encoding=self.encoding) + if not statements: + return last_stmt = statements[-1] for i in xrange(len(last_stmt.tokens) - 1, 0, -1): token = last_stmt.tokens[i] - if token.ttype in (Comment.Single, Comment.Multiline, - Newline, Whitespace): + if is_comment_or_space(token): continue - if token.ttype is Punctuation and token.value == ';': + if token.ttype is T.Punctuation and token.value == ';': break # The last statement is complete # The last statement is still incomplete - wait for the next line return @@ -47,14 +72,16 @@ class StatementGrouper(object): for stmt in self.statements: yield stmt self.statements = [] + return def close(self): if not self.lines: return tokens = parse(''.join(self.lines), encoding=self.encoding) for token in tokens: - if (token.ttype not in (Comment.Single, Comment.Multiline, - Newline, Whitespace)): + if not is_comment_or_space(token): raise ValueError("Incomplete SQL statement: %s" % tokens) + self.lines = [] + self.statements = [] return tokens