diff --git a/src/p1_mysql_auth.erl b/src/p1_mysql_auth.erl index eabb885..7e616d9 100644 --- a/src/p1_mysql_auth.erl +++ b/src/p1_mysql_auth.erl @@ -18,7 +18,8 @@ %%-------------------------------------------------------------------- -export([ do_old_auth/7, - do_new_auth/8 + do_new_auth/8, + do_new_auth/9 ]). %%-------------------------------------------------------------------- @@ -72,8 +73,27 @@ do_old_auth(Sock, RecvPid, SeqNum, User, Password, Salt1, LogFun) -> %% Returns : result of p1_mysql_conn:do_recv/3 %%-------------------------------------------------------------------- do_new_auth(Sock, RecvPid, SeqNum, User, Password, Salt1, Salt2, LogFun) -> + do_new_auth(Sock, RecvPid, SeqNum, User, Password, Salt1, Salt2, undefinded, LogFun) + . + +%%-------------------------------------------------------------------- +%% Function: do_new_auth(Sock, RecvPid, SeqNum, User, Password, Salt1, +%% Salt2, LogFun) +%% Sock = term(), gen_tcp socket +%% RecvPid = pid(), receiver process pid +%% SeqNum = integer(), first sequence number we should use +%% User = string(), MySQL username +%% Password = string(), MySQL password +%% Salt1 = string(), salt 1 from server greeting +%% Salt2 = string(), salt 2 from server greeting +%% MaxPacketSize = integer(), mysql packet size +%% LogFun = undefined | function() of arity 3 +%% Descrip.: Perform MySQL authentication. +%% Returns : result of p1_mysql_conn:do_recv/3 +%%-------------------------------------------------------------------- +do_new_auth(Sock, RecvPid, SeqNum, User, Password, Salt1, Salt2, MaxPacketSize, LogFun) -> Auth = password_new(Password, Salt1 ++ Salt2), - Packet2 = make_new_auth(User, Auth, none), + Packet2 = make_new_auth(User, Auth, none, MaxPacketSize), do_send(Sock, Packet2, SeqNum, LogFun), case p1_mysql_conn:do_recv(LogFun, RecvPid, SeqNum) of {ok, Packet3, SeqNum2} -> @@ -108,14 +128,14 @@ password_old(Password, Salt) -> make_auth(User, Password) -> Caps = ?LONG_PASSWORD bor ?LONG_FLAG bor ?TRANSACTIONS bor ?FOUND_ROWS, - Maxsize = 0, + Maxsize = ?MAX_PACKET_SIZE, UserB = list_to_binary(User), PasswordB = Password, <>. %% part of do_new_auth/4, which is part of mysql_init/4 -make_new_auth(User, Password, Database) -> +make_new_auth(User, Password, Database, MaxPacketSize) -> DBCaps = case Database of none -> 0; @@ -125,7 +145,11 @@ make_new_auth(User, Password, Database) -> Caps = ?LONG_PASSWORD bor ?LONG_FLAG bor ?TRANSACTIONS bor ?PROTOCOL_41 bor ?SECURE_CONNECTION bor DBCaps bor ?FOUND_ROWS, - Maxsize = ?MAX_PACKET_SIZE, + Maxsize = case MaxPacketSize /= undefined of + true -> MaxPacketSize; + false -> ?MAX_PACKET_SIZE; + _ -> ?MAX_PACKET_SIZE + end, UserB = list_to_binary(User), PasswordL = size(Password), DatabaseB = case Database of diff --git a/src/p1_mysql_conn.erl b/src/p1_mysql_conn.erl index 365d80b..eed80c7 100644 --- a/src/p1_mysql_conn.erl +++ b/src/p1_mysql_conn.erl @@ -63,8 +63,10 @@ %%-------------------------------------------------------------------- %% External exports %%-------------------------------------------------------------------- --export([start/6, +-export([start/7, + start/6, start_link/6, + start_link/7, fetch/3, fetch/4, squery/4, @@ -113,28 +115,61 @@ %% Reason = string() %%-------------------------------------------------------------------- start(Host, Port, User, Password, - Database, LogFun) when is_list(Host), + Database, LogFun) when is_list(Host), + is_integer(Port), + is_list(User), + is_list(Password), + is_list(Database) -> + ConnPid = self(), + Pid = spawn(fun () -> + init(Host, Port, User, Password, Database, undefined, + LogFun, ConnPid) + end), + post_start(Pid, LogFun). + +%%-------------------------------------------------------------------- +%% Function: start(Host, Port, User, Password, Database, LogFun) +%% Function: start_link(Host, Port, User, Password, Database, LogFun) +%% Host = string() +%% Port = integer() +%% User = string() +%% Password = string() +%% Database = string() +%% MaxPacketSize = integer() +%% LogFun = undefined | function() of arity 3 +%% Descrip.: Starts a p1_mysql_conn process that connects to a MySQL +%% server, logs in and chooses a database. +%% Returns : {ok, Pid} | {error, Reason} +%% Pid = pid() +%% Reason = string() +%%-------------------------------------------------------------------- +start(Host, Port, User, Password, + Database, MaxPacketSize, LogFun) when is_list(Host), is_integer(Port), is_list(User), is_list(Password), is_list(Database) -> ConnPid = self(), Pid = spawn(fun () -> - init(Host, Port, User, Password, Database, + init(Host, Port, User, Password, Database, MaxPacketSize, LogFun, ConnPid) end), post_start(Pid, LogFun). start_link(Host, Port, User, Password, - Database, LogFun) when is_list(Host), + Database, LogFun) -> + start_link(Host, Port, User, Password, Database, undefined, LogFun) +. + +start_link(Host, Port, User, Password, + Database, MaxPacketSize, LogFun) when is_list(Host), is_integer(Port), is_list(User), is_list(Password), is_list(Database) -> ConnPid = self(), Pid = spawn_link(fun () -> - init(Host, Port, User, Password, Database, - LogFun, ConnPid) + init(Host, Port, User, Password, Database, MaxPacketSize, LogFun, ConnPid) end), post_start(Pid, LogFun). @@ -288,6 +323,7 @@ do_recv(LogFun, RecvPid, SeqNum) when is_function(LogFun); %% User = string() %% Password = string() %% Database = string() +%% MaxPacketSize = integer() %% LogFun = undefined | function() of arity 3 %% Parent = pid() of process starting this p1_mysql_conn %% Descrip.: Connect to a MySQL server, log in and chooses a database. @@ -295,10 +331,10 @@ do_recv(LogFun, RecvPid, SeqNum) when is_function(LogFun); %% we were successfull. %% Returns : void() | does not return %%-------------------------------------------------------------------- -init(Host, Port, User, Password, Database, LogFun, Parent) -> +init(Host, Port, User, Password, Database, MaxPacketSize, LogFun, Parent) -> case p1_mysql_recv:start_link(Host, Port, LogFun, self()) of {ok, RecvPid, Sock} -> - case mysql_init(Sock, RecvPid, User, Password, LogFun) of + case mysql_init(Sock, RecvPid, User, Password, MaxPacketSize, LogFun) of {ok, Version} -> case do_query(Sock, RecvPid, LogFun, "use " ++ Database, Version, [{result_type, binary}]) of @@ -388,12 +424,13 @@ loop(State) -> %% RecvPid = pid(), p1_mysql_recv process %% User = string() %% Password = string() +%% MaxPacketSize = integer() %% LogFun = undefined | function() with arity 3 %% Descrip.: Try to authenticate on our new socket. %% Returns : ok | {error, Reason} %% Reason = string() %%-------------------------------------------------------------------- -mysql_init(Sock, RecvPid, User, Password, LogFun) -> +mysql_init(Sock, RecvPid, User, Password, MaxPacketSize, LogFun) -> case do_recv(LogFun, RecvPid, undefined) of {ok, Packet, InitSeqNum} -> {Version, Salt1, Salt2, Caps} = greeting(Packet, LogFun), @@ -403,7 +440,7 @@ mysql_init(Sock, RecvPid, User, Password, LogFun) -> p1_mysql_auth:do_new_auth(Sock, RecvPid, InitSeqNum + 1, User, Password, - Salt1, Salt2, LogFun); + Salt1, Salt2, MaxPacketSize, LogFun); _ -> p1_mysql_auth:do_old_auth(Sock, RecvPid, InitSeqNum + 1,