diff --git a/web_demo2.py b/web_demo2.py index 203cbdc..2c91ed3 100644 --- a/web_demo2.py +++ b/web_demo2.py @@ -45,19 +45,14 @@ def get_model(): st.markdown(query) with st.chat_message(name="assistant", avatar="assistant"): st.markdown(response) -with st.chat_message(name="user", avatar="user"): - input_placeholder = st.empty() -with st.chat_message(name="assistant", avatar="assistant"): - message_placeholder = st.empty() -prompt_text = st.text_area(label="用户命令输入", - height=100, - placeholder="请在这儿输入您的命令") - -button = st.button("发送", key="predict") +prompt_text = st.chat_input(placeholder="请在这儿输入您的命令") +if prompt_text: + with st.chat_message(name="user", avatar="user"): + st.markdown(prompt_text) + with st.chat_message(name="assistant", avatar="assistant"): + message_placeholder = st.empty() -if button: - input_placeholder.markdown(prompt_text) history, past_key_values = st.session_state.history, st.session_state.past_key_values for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history, past_key_values=past_key_values,