from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
    Statement, Token
from sqlparse import tokens as T
from .process_tokens import escape_strings, is_comment_or_space


def _is_directive_token(token):
    if isinstance(token, Comment):
        subtokens = token.tokens
        if subtokens:
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False


def is_directive_statement(statement):
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
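
# For example, mysqldump emits standalone conditional comments such as
# ``/*!40101 SET NAMES utf8 */;`` (version number and body illustrative):
# the whole statement matches is_directive_statement(), while the bare
# ``/*!...*/`` comment token alone matches _is_directive_token().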


def remove_directive_tokens(statement):
    """Remove /*! directives */ from the first level of the statement"""
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue
        new_tokens.append(token)
    statement.tokens = new_tokens
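
# A dumped statement can also carry a directive inline, e.g. a CREATE
# TABLE ending in ``) ENGINE=InnoDB /*!40101 DEFAULT CHARSET=utf8 */;``
# (illustrative); when sqlparse keeps that comment at the statement's
# first level, dropping it leaves SQL that SQLite can accept.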


def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers"""
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]
            if value.islower():
                token.normalized = token.value = value
            else:
                token.normalized = token.value = '"%s"' % value
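
# Sketch of the effect, assuming sqlparse lexes backtick-quoted names as
# single T.Name tokens with the backticks included:
#   `Users` -> "Users"   (not all-lowercase, so it must stay quoted)
#   `users` -> users     (lowercase, safe to leave bare)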


def unescape_strings(token_list):
    """Unescape strings"""
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = token.value
            # MySQL escape sequences to undo.  NOTE: this table is an
            # assumption (the exact set in the original was elided);
            # specific escapes are replaced before the bare backslash.
            replacements = (
                ("\\'", "'"),
                ('\\"', '"'),
                ("\\n", "\n"),
                ("\\r", "\r"),
                ("\\t", "\t"),
                ("\\0", "\0"),
                ("\\\\", "\\"),
            )
            for orig, repl in replacements:
                value = value.replace(orig, repl)
            token.normalized = token.value = value
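
# With the table above, a dumped literal like 'it\'s one\ntwo' becomes a
# string containing a real quote and newline; escape_strings() (imported
# above) can then re-quote it for the target style in process_statement().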


def is_insert(statement):
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        return (token.ttype is T.DML) and (token.normalized == 'INSERT')
    return False


def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs"""
    insert_tokens = []
    values_tokens = []
    end_tokens = []
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            if expected == 'VALUES_OR_SEMICOLON':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    continue
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Insert newlines between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement:
            # insert newlines only between split statements, not after
            new_lines = []
        # The Statement constructor sets the `parent` attribute of every
        # token to itself, so the shared tokens are re-parented on each
        # iteration
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
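
# Roughly, the extended insert
#   INSERT INTO t VALUES (1),(2);
# is re-emitted as two standalone statements:
#   INSERT INTO t VALUES (1);
#   INSERT INTO t VALUES (2);
# (spacing is approximate).  Note that newer sqlparse releases may group
# the whole VALUES clause into a single sql.Values token, which this
# state machine would reject.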


def process_statement(statement, quoting_style='sqlite'):
    requote_names(statement)
    unescape_strings(statement)
    remove_directive_tokens(statement)
    escape_strings(statement, quoting_style)
    if is_insert(statement):
        for statement in split_ext_insert(statement):
            yield statement
    else:
        yield statement
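

if __name__ == '__main__':
    # Ad-hoc demo, not part of the conversion pipeline.  Run it as
    # ``python -m <package>.<this module>`` so the relative import of
    # process_tokens resolves; the sample SQL below is made up.
    import sqlparse
    demo = "SELECT `Name` FROM `users` WHERE note = 'it\\'s';"
    for parsed in sqlparse.parse(demo):
        for result in process_statement(parsed):
            print(result)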