Match stable diffusion result's histogram to image's
This commit is contained in:
parent
0b00fffe13
commit
8e408640a4
@ -54,6 +54,7 @@ export default async function inpaint(
|
|||||||
fd.append('sdGuidanceScale', settings.sdGuidanceScale.toString())
|
fd.append('sdGuidanceScale', settings.sdGuidanceScale.toString())
|
||||||
fd.append('sdSampler', settings.sdSampler.toString())
|
fd.append('sdSampler', settings.sdSampler.toString())
|
||||||
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
fd.append('sdSeed', seed ? seed.toString() : '-1')
|
||||||
|
fd.append('sdMatchHistograms', settings.sdMatchHistograms ? 'true' : 'false')
|
||||||
|
|
||||||
fd.append('cv2Radius', settings.cv2Radius.toString())
|
fd.append('cv2Radius', settings.cv2Radius.toString())
|
||||||
fd.append('cv2Flag', settings.cv2Flag.toString())
|
fd.append('cv2Flag', settings.cv2Flag.toString())
|
||||||
|
@ -120,6 +120,22 @@ const SidePanel = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<SettingBlock
|
||||||
|
title="Match Histograms"
|
||||||
|
input={
|
||||||
|
<Switch
|
||||||
|
checked={setting.sdMatchHistograms}
|
||||||
|
onCheckedChange={value => {
|
||||||
|
setSettingState(old => {
|
||||||
|
return { ...old, sdMatchHistograms: value }
|
||||||
|
})
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<SwitchThumb />
|
||||||
|
</Switch>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
|
||||||
<SettingBlock
|
<SettingBlock
|
||||||
className="sub-setting-block"
|
className="sub-setting-block"
|
||||||
title="Sampler"
|
title="Sampler"
|
||||||
|
@ -175,6 +175,7 @@ export interface Settings {
|
|||||||
sdSeed: number
|
sdSeed: number
|
||||||
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
|
sdSeedFixed: boolean // true: use sdSeed, false: random generate seed on backend
|
||||||
sdNumSamples: number
|
sdNumSamples: number
|
||||||
|
sdMatchHistograms: boolean
|
||||||
|
|
||||||
// For OpenCV2
|
// For OpenCV2
|
||||||
cv2Radius: number
|
cv2Radius: number
|
||||||
@ -278,6 +279,7 @@ export const settingStateDefault: Settings = {
|
|||||||
sdSeed: 42,
|
sdSeed: 42,
|
||||||
sdSeedFixed: true,
|
sdSeedFixed: true,
|
||||||
sdNumSamples: 1,
|
sdNumSamples: 1,
|
||||||
|
sdMatchHistograms: false,
|
||||||
|
|
||||||
// CV2
|
// CV2
|
||||||
cv2Radius: 5,
|
cv2Radius: 5,
|
||||||
|
@ -56,6 +56,9 @@ class InpaintModel:
|
|||||||
result = self.forward(pad_image, pad_mask, config)
|
result = self.forward(pad_image, pad_mask, config)
|
||||||
result = result[0:origin_height, 0:origin_width, :]
|
result = result[0:origin_height, 0:origin_width, :]
|
||||||
|
|
||||||
|
if config.sd_match_histograms:
|
||||||
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
if config.sd_mask_blur != 0:
|
||||||
k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||||
@ -172,6 +175,44 @@ class InpaintModel:
|
|||||||
|
|
||||||
return crop_img, crop_mask, [l, t, r, b]
|
return crop_img, crop_mask, [l, t, r, b]
|
||||||
|
|
||||||
|
def _calculate_cdf(self, histogram):
|
||||||
|
cdf = histogram.cumsum()
|
||||||
|
normalized_cdf = cdf / float(cdf.max())
|
||||||
|
return normalized_cdf
|
||||||
|
|
||||||
|
def _calculate_lookup(self, source_cdf, reference_cdf):
|
||||||
|
lookup_table = np.zeros(256)
|
||||||
|
lookup_val = 0
|
||||||
|
for source_index, source_val in enumerate(source_cdf):
|
||||||
|
for reference_index, reference_val in enumerate(reference_cdf):
|
||||||
|
if reference_val >= source_val:
|
||||||
|
lookup_val = reference_index
|
||||||
|
break
|
||||||
|
lookup_table[source_index] = lookup_val
|
||||||
|
return lookup_table
|
||||||
|
|
||||||
|
def _match_histograms(self, source, reference, mask):
|
||||||
|
transformed_channels = []
|
||||||
|
for channel in range(source.shape[-1]):
|
||||||
|
source_channel = source[:, :, channel]
|
||||||
|
reference_channel = reference[:, :, channel]
|
||||||
|
|
||||||
|
# only calculate histograms for non-masked parts
|
||||||
|
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0,256])
|
||||||
|
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0,256])
|
||||||
|
|
||||||
|
source_cdf = self._calculate_cdf(source_histogram)
|
||||||
|
reference_cdf = self._calculate_cdf(reference_histogram)
|
||||||
|
|
||||||
|
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
||||||
|
|
||||||
|
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
||||||
|
|
||||||
|
result = cv2.merge(transformed_channels)
|
||||||
|
result = cv2.convertScaleAbs(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def _run_box(self, image, mask, box, config: Config):
|
def _run_box(self, image, mask, box, config: Config):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ class Config(BaseModel):
|
|||||||
sd_sampler: str = SDSampler.ddim
|
sd_sampler: str = SDSampler.ddim
|
||||||
# -1 mean random seed
|
# -1 mean random seed
|
||||||
sd_seed: int = 42
|
sd_seed: int = 42
|
||||||
|
sd_match_histograms: bool = False
|
||||||
|
|
||||||
# cv2
|
# cv2
|
||||||
cv2_flag: str = 'INPAINT_NS'
|
cv2_flag: str = 'INPAINT_NS'
|
||||||
|
@ -131,6 +131,7 @@ def process():
|
|||||||
sd_guidance_scale=form["sdGuidanceScale"],
|
sd_guidance_scale=form["sdGuidanceScale"],
|
||||||
sd_sampler=form["sdSampler"],
|
sd_sampler=form["sdSampler"],
|
||||||
sd_seed=form["sdSeed"],
|
sd_seed=form["sdSeed"],
|
||||||
|
sd_match_histograms=form["sdMatchHistograms"],
|
||||||
cv2_flag=form["cv2Flag"],
|
cv2_flag=form["cv2Flag"],
|
||||||
cv2_radius=form['cv2Radius']
|
cv2_radius=form['cv2Radius']
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user