from sqlparse import tokens as T
from sqlparse.sql import Comment, Function, Identifier, Parenthesis, \
    Statement, Token

from .process_tokens import escape_strings, is_comment_or_space
def _is_directive_token(token):
    """Return True if *token* is a MySQL executable-comment directive.

    MySQL dumps wrap server-specific statements in multiline comments of
    the form ``/*! ... */``; sqlparse groups these as ``Comment`` nodes
    whose first subtoken is the multiline-comment literal.
    """
    if isinstance(token, Comment):
        subtokens = token.tokens
        # Guard against an empty Comment group before indexing.
        if subtokens:
            comment = subtokens[0]
            if comment.ttype is T.Comment.Multiline and \
                    comment.value.startswith('/*!'):
                return True
    return False
def is_directive_statement(statement):
    """Return True if *statement* is nothing but a ``/*! ... */`` directive.

    A directive statement is: the directive comment itself, optionally
    followed by whitespace/newlines, terminated by a ``;`` -- and nothing
    else at the top level.
    """
    tokens = statement.tokens
    if not _is_directive_token(tokens[0]):
        return False
    # Must end with an explicit semicolon.
    if tokens[-1].ttype is not T.Punctuation or tokens[-1].value != ';':
        return False
    # Everything between directive and semicolon must be blank.
    for token in tokens[1:-1]:
        if token.ttype not in (T.Newline, T.Whitespace):
            return False
    return True
def remove_directive_tokens(statement):
    """Remove ``/*! ... */`` directives from the statement's first level.

    Mutates ``statement.tokens`` in place: drops every top-level
    directive comment and keeps all other tokens in their original order.
    """
    new_tokens = []
    for token in statement.tokens:
        if _is_directive_token(token):
            continue  # skip directives; keep everything else
        new_tokens.append(token)
    statement.tokens = new_tokens
def requote_names(token_list):
    """Remove backticks, quote non-lowercase identifiers.

    MySQL quotes identifiers with backticks; SQLite/standard SQL uses
    double quotes.  All-lowercase names are left bare, anything else is
    double-quoted to preserve its case.
    """
    for token in token_list.flatten():
        if token.ttype is T.Name:
            value = token.value
            # Strip MySQL backtick quoting if present.
            if (value[0] == "`") and (value[-1] == "`"):
                value = value[1:-1]
            if value.islower():
                # Lowercase identifiers are safe unquoted.
                token.normalized = token.value = value
            else:
                # Preserve case via standard SQL double quoting.
                token.normalized = token.value = '"%s"' % value
def unescape_strings(token_list):
    """Unescape strings.

    Rewrites every single-quoted string literal, replacing MySQL
    backslash escape sequences with the characters they denote (the
    target escaping for the output database is applied later by
    ``escape_strings``).
    """
    # NOTE(review): the replacement table was reconstructed from the loop
    # that consumes it -- confirm the pairs (and their order; sequential
    # str.replace is order-sensitive for backslash runs) against MySQL's
    # string-literal escape rules.
    replacements = (
        ("\\'", "'"),
        ('\\"', '"'),
        ("\\n", "\n"),
        ("\\r", "\r"),
        ("\\t", "\t"),
        ("\\b", "\b"),
        ("\\0", "\0"),
        ("\\Z", "\x1a"),
        ("\\\\", "\\"),
    )
    for token in token_list.flatten():
        if token.ttype is T.String.Single:
            value = token.value
            for orig, repl in replacements:
                value = value.replace(orig, repl)
            token.normalized = token.value = value
def get_DML_type(statement):
    """Return the normalized DML keyword of *statement* (e.g. ``'INSERT'``).

    Leading comments and whitespace are skipped.  Raises ``ValueError``
    when the first significant token is not a DML keyword.
    """
    for token in statement.tokens:
        if is_comment_or_space(token):
            continue
        if (token.ttype is T.DML):
            return token.normalized
        # First significant token is not DML -- stop looking.
        break
    raise ValueError("Not a DML statement")
def split_ext_insert(statement):
    """Split extended INSERT into multiple standard INSERTs.

    MySQL dumps emit ``INSERT INTO t VALUES (...), (...), ...;``.
    This generator yields one ``Statement`` per value tuple, each of the
    form ``INSERT INTO t VALUES (...);``, with newlines inserted between
    (but not after) the split statements.

    Raises ``ValueError`` on a token that does not fit the expected
    INSERT grammar.
    """
    insert_tokens = []   # 'INSERT INTO <table> VALUES' prefix
    values_tokens = []   # one Parenthesis per value tuple
    end_tokens = []      # trailing tokens (';', trailing comments)
    expected = 'INSERT'
    for token in statement.tokens:
        if is_comment_or_space(token):
            # Keep comments/whitespace: after VALUES they belong to the
            # tail, before it to the shared prefix.
            if expected == 'VALUES_OR_SEMICOLON':
                end_tokens.append(token)
            else:
                insert_tokens.append(token)
            continue
        elif expected == 'INSERT':
            if (token.ttype is T.DML) and (token.normalized == 'INSERT'):
                insert_tokens.append(token)
                expected = 'INTO'
                continue
        elif expected == 'INTO':
            if (token.ttype is T.Keyword) and (token.normalized == 'INTO'):
                insert_tokens.append(token)
                expected = 'TABLE_NAME'
                continue
        elif expected == 'TABLE_NAME':
            # A bare name parses as Identifier; a name followed by a
            # column list parses as Function.
            if isinstance(token, (Function, Identifier)):
                insert_tokens.append(token)
                expected = 'VALUES'
                continue
        elif expected == 'VALUES':
            if (token.ttype is T.Keyword) and (token.normalized == 'VALUES'):
                insert_tokens.append(token)
                expected = 'VALUES_OR_SEMICOLON'
                continue
        elif expected == 'VALUES_OR_SEMICOLON':
            if isinstance(token, Parenthesis):
                values_tokens.append(token)
                continue
            elif token.ttype is T.Punctuation:
                if token.value == ',':
                    # Separator between value tuples -- drop it.
                    continue
                elif token.value == ';':
                    end_tokens.append(token)
                    continue
        raise ValueError(
            'SQL syntax error: expected "%s", got %s "%s"' % (
                expected, token.ttype, token.normalized))
    new_line = Token(T.Newline, '\n')
    new_lines = [new_line]  # Insert newlines between split statements
    for i, values in enumerate(values_tokens):
        if i == len(values_tokens) - 1:  # Last statement
            # Insert newlines only between split statements but not after
            new_lines = []
        # The statement sets `parent` attribute of the every token to self
        statement = Statement(insert_tokens + [values] +
                              end_tokens + new_lines)
        yield statement
144 def process_statement(statement, dbname='sqlite'):
145 requote_names(statement)
146 unescape_strings(statement)
147 remove_directive_tokens(statement)
148 escape_strings(statement, dbname)
150 dml_type = get_DML_type(statement)
153 if dml_type == 'INSERT':
154 for statement in split_ext_insert(statement):