Imagegen delay retry huggingface (#4194)
Co-authored-by: Kory Becker <kbecker@primaryobjects.com> Co-authored-by: Nicholas Tindle <nick@ntindle.com> Co-authored-by: Nicholas Tindle <nicktindle@outlook.com> Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>pull/4036/head
parent
812be60d2a
commit
ee98641210
|
@ -1,6 +1,7 @@
|
|||
import functools
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
@ -106,6 +107,55 @@ def generate_and_validate(
|
|||
assert img.size == (image_size, image_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"return_text",
|
||||
[
|
||||
'{"error":"Model [model] is currently loading","estimated_time": [delay]}', # Delay
|
||||
'{"error":"Model [model] is currently loading"}', # No delay
|
||||
'{"error:}', # Bad JSON
|
||||
"", # Bad Image
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"image_model",
|
||||
["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
|
||||
)
|
||||
@pytest.mark.parametrize("delay", [10, 0])
|
||||
def test_huggingface_fail_request_with_delay(
|
||||
config, workspace, image_size, image_model, return_text, delay
|
||||
):
|
||||
return_text = return_text.replace("[model]", image_model).replace(
|
||||
"[delay]", str(delay)
|
||||
)
|
||||
|
||||
with patch("requests.post") as mock_post:
|
||||
if return_text == "":
|
||||
# Test bad image
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.ok = True
|
||||
mock_post.return_value.content = b"bad image"
|
||||
else:
|
||||
# Test delay and bad json
|
||||
mock_post.return_value.status_code = 500
|
||||
mock_post.return_value.ok = False
|
||||
mock_post.return_value.text = return_text
|
||||
|
||||
config.image_provider = "huggingface"
|
||||
config.huggingface_image_model = image_model
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Verify request fails.
|
||||
result = generate_image(prompt, image_size)
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was called with delay if delay is in return_text
|
||||
if "estimated_time" in return_text:
|
||||
mock_sleep.assert_called_with(delay)
|
||||
else:
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_with_delay(mocker):
|
||||
config = Config()
|
||||
config.huggingface_api_token = "1"
|
||||
|
|
Loading…
Reference in New Issue