Imagegen delay retry huggingface (#4194)

Co-authored-by: Kory Becker <kbecker@primaryobjects.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
pull/4036/head
Luke K (pr-0f3t) 2023-05-19 13:19:39 -04:00 committed by GitHub
parent 812be60d2a
commit ee98641210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 50 additions and 0 deletions

View File

@ -1,6 +1,7 @@
import functools
import hashlib
from pathlib import Path
from unittest.mock import patch
import pytest
from PIL import Image
@ -106,6 +107,55 @@ def generate_and_validate(
assert img.size == (image_size, image_size)
@pytest.mark.parametrize(
"return_text",
[
'{"error":"Model [model] is currently loading","estimated_time": [delay]}', # Delay
'{"error":"Model [model] is currently loading"}', # No delay
'{"error:}', # Bad JSON
"", # Bad Image
],
)
@pytest.mark.parametrize(
"image_model",
["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
)
@pytest.mark.parametrize("delay", [10, 0])
def test_huggingface_fail_request_with_delay(
config, workspace, image_size, image_model, return_text, delay
):
return_text = return_text.replace("[model]", image_model).replace(
"[delay]", str(delay)
)
with patch("requests.post") as mock_post:
if return_text == "":
# Test bad image
mock_post.return_value.status_code = 200
mock_post.return_value.ok = True
mock_post.return_value.content = b"bad image"
else:
# Test delay and bad json
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = return_text
config.image_provider = "huggingface"
config.huggingface_image_model = image_model
prompt = "astronaut riding a horse"
with patch("time.sleep") as mock_sleep:
# Verify request fails.
result = generate_image(prompt, image_size)
assert result == "Error creating image."
# Verify retry was called with delay if delay is in return_text
if "estimated_time" in return_text:
mock_sleep.assert_called_with(delay)
else:
mock_sleep.assert_not_called()
def test_huggingface_fail_request_with_delay(mocker):
config = Config()
config.huggingface_api_token = "1"