-Fix semicolons and newlines after /*! directives */
-
-
Convert string escapes to generic SQL, Postgres- or SQLite-specific.
from sqlparse.compat import text_type
from sqlconvert.print_tokens import print_tokens
-from sqlconvert.process_mysql import process_statement
-from sqlconvert.process_tokens import StatementGrouper
+from sqlconvert.process_mysql import is_directive_statement, process_statement
+from sqlconvert.process_tokens import is_newline_statement, StatementGrouper
from m_lib.defenc import default_encoding
from m_lib.pbar.tty_pbar import ttyProgressBar
cur_pos = 0
grouper = StatementGrouper(encoding=encoding)
+ got_directive = False
for line in infile:
if use_pbar:
if isinstance(line, text_type):
grouper.process_line(line)
if grouper.statements:
for statement in grouper.get_statements():
+ if got_directive and is_newline_statement(statement):
+ # Replace a sequence of newlines after a /*! directive */;
+ # with one newline
+ #outfile.write(u'\n')
+ continue
+ got_directive = is_directive_statement(statement)
+ if got_directive:
+ continue
process_statement(statement)
print_tokens(statement, outfile=outfile,
encoding=output_encoding)
from sqlparse import tokens as T
def _is_directive_token(token):
    """Return True when *token* is a MySQL ``/*! directive */`` comment."""
    if not isinstance(token, Comment):
        return False
    subtokens = token.tokens
    if not subtokens:
        return False
    first = subtokens[0]
    return first.ttype is T.Comment.Multiline and first.value.startswith('/*!')
+
+
def is_directive_statement(statement):
    """Return True when the statement is nothing but a /*! directive */
    followed by optional whitespace and a terminating semicolon."""
    toks = statement.tokens
    if not _is_directive_token(toks[0]):
        return False
    last = toks[-1]
    if last.ttype is not T.Punctuation or last.value != ';':
        return False
    return all(tok.ttype in (T.Newline, T.Whitespace) for tok in toks[1:-1])
+
+
def remove_directives(statement):
    """Remove /*! directives */ from the first-level"""
    statement.tokens = [
        token for token in statement.tokens
        if not _is_directive_token(token)
    ]
+
+
def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers"""
    # NOTE(review): `value` is not defined in this visible span -- the lines
    # that derive it from token.value (backtick stripping / case handling)
    # appear to be elided here; confirm against the full source file.
    for token in token_list.flatten():
        token.normalized = token.value = '"%s"' % value
-def remove_directives(statement):
- """Remove /*! directives */ from the first-level"""
- new_tokens = []
- for token in statement.tokens:
- if isinstance(token, Comment):
- subtokens = token.tokens
- if subtokens:
- comment = subtokens[0]
- if comment.ttype is T.Comment.Multiline and \
- comment.value.startswith('/*!'):
- continue
- new_tokens.append(token)
- statement.tokens = new_tokens
-
-
def process_statement(statement):
    """Convert a parsed MySQL statement in place.

    Strips first-level /*! directives */ first, then requotes identifiers;
    removing directives before requoting keeps identifier handling away from
    directive contents.
    """
    remove_directives(statement)
    requote_names(statement)
from sqlparse import parse
from sqlparse.compat import PY3
-from sqlparse.tokens import Error, Punctuation, Comment, Newline, Whitespace
+from sqlparse import tokens as T
def find_error(token_list):
    """Return True when the flattened token stream contains an Error token."""
    return any(token.ttype is T.Error for token in token_list.flatten())
def is_newline_statement(statement):
    """Return True when the statement consists solely of newline tokens.

    Used by the converter to collapse runs of blank lines that follow a
    removed /*! directive */ statement.
    """
    # The original iterated over statement.tokens[:] -- a needless list copy,
    # since the loop never mutates the token list.
    return all(token.ttype is T.Newline for token in statement.tokens)
+
+
# Py2/Py3 compatibility: Python 3 has no xrange, so alias the built-in range.
if PY3:
    xrange = range
last_stmt = statements[-1]
for i in xrange(len(last_stmt.tokens) - 1, 0, -1):
token = last_stmt.tokens[i]
- if token.ttype in (Comment.Single, Comment.Multiline,
- Newline, Whitespace):
+ if token.ttype in (T.Comment.Single, T.Comment.Multiline,
+ T.Newline, T.Whitespace):
continue
- if token.ttype is Punctuation and token.value == ';':
+ if token.ttype is T.Punctuation and token.value == ';':
break # The last statement is complete
# The last statement is still incomplete - wait for the next line
return
return
tokens = parse(''.join(self.lines), encoding=self.encoding)
for token in tokens:
- if (token.ttype not in (Comment.Single, Comment.Multiline,
- Newline, Whitespace)):
+ if (token.ttype not in (T.Comment.Single, T.Comment.Multiline,
+ T.Newline, T.Whitespace)):
raise ValueError("Incomplete SQL statement: %s" %
tokens)
return tokens
from sqlparse import parse
from sqlconvert.print_tokens import tlist2str
-from sqlconvert.process_mysql import requote_names, remove_directives, \
- process_statement
+from sqlconvert.process_mysql import remove_directives, requote_names, \
+ is_directive_statement, process_statement
from tests import main
query = tlist2str(parsed)
self.assertEqual(query, 'SELECT * FROM "T"')
- def test_directives(self):
+ def test_directive(self):
parsed = parse("select /*! test */ * from /* test */ `T`")[0]
remove_directives(parsed)
query = tlist2str(parsed)
self.assertEqual(query, 'SELECT * FROM /* test */ `T`')
+ def test_directive_statement(self):
+ parsed = parse("/*! test */ test ;")[0]
+ self.assertFalse(is_directive_statement(parsed))
+ parsed = parse("/*! test */ ;")[0]
+ self.assertTrue(is_directive_statement(parsed))
+
def test_process(self):
parsed = parse("select /*! test */ * from /* test */ `T`")[0]
process_statement(parsed)