# choiseungmi

# Upload our codec

# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import struct
import sys
import time
import math
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import compressai
from compressai.zoo import models
# Numeric id for each model name, in the iteration order of the model zoo.
# get_header() stores this id in the bit-stream header.
model_ids = {name: idx for idx, name in enumerate(models)}
# Numeric id for each supported training metric (also stored in the header).
metric_ids = {"mse": 0}
def inverse_dict(d):
    """Return a new dict mapping every value of ``d`` back to its key.

    Raises:
        ValueError: if two keys share the same value, because the inverse
            mapping would silently drop one of them.
    """
    inverted = {value: key for key, value in d.items()}
    # Uniqueness has to be checked on the values; comparing the keys with
    # themselves (as the previous assert did) can never fail.
    if len(inverted) != len(d):
        raise ValueError("Cannot invert a dict with duplicate values.")
    return inverted
def filesize(filepath: str) -> int:
    """Return the size in bytes of the regular file at ``filepath``.

    Raises:
        ValueError: if ``filepath`` is not an existing regular file.
    """
    target = Path(filepath)
    if target.is_file():
        return target.stat().st_size
    raise ValueError(f'Invalid file "{filepath}".')
def load_image(filepath: str) -> Image.Image:
    """Open the image at ``filepath`` and return it converted to RGB mode."""
    image = Image.open(filepath)
    return image.convert("RGB")
def img2torch(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a tensor and prepend a batch dimension of size 1."""
    to_tensor = ToTensor()
    tensor = to_tensor(img)
    return tensor.unsqueeze(0)
def torch2img(x: torch.Tensor) -> Image.Image:
    """Convert a tensor to a PIL image.

    ``x`` is clamped to [0, 1] *in place*, so the caller's tensor is modified
    as a side effect.
    """
    clamped = x.clamp_(0, 1)
    to_pil = ToPILImage()
    return to_pil(clamped.squeeze())
def write_uints(fd, values, fmt=">{:d}I"):
    """Pack ``values`` as unsigned ints (big-endian by default) and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def write_uchars(fd, values, fmt=">{:d}B"):
    """Pack ``values`` as unsigned bytes and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def read_uints(fd, n, fmt=">{:d}I"):
    """Read ``n`` unsigned ints from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("I")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def read_uchars(fd, n, fmt=">{:d}B"):
    """Read ``n`` unsigned bytes from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("B")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def write_bytes(fd, values, fmt=">{:d}s"):
    """Write the byte string ``values`` to ``fd``; an empty string writes nothing."""
    count = len(values)
    if count:
        fd.write(struct.pack(fmt.format(count), values))
def read_bytes(fd, n, fmt=">{:d}s"):
    """Read ``n`` bytes from ``fd`` and return them as one ``bytes`` object."""
    raw = fd.read(n * struct.calcsize("s"))
    (payload,) = struct.unpack(fmt.format(n), raw)
    return payload
def get_header(model_name, metric, quality):
    """Build the two header values describing the codec configuration.

    Layout:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    metric_id = metric_ids[metric]
    quality_bits = (quality - 1) & 0x0F
    packed = (metric_id << 4) | quality_bits
    return model_ids[model_name], packed
def parse_header(header):
    """Decode the two header values into ``(model_name, metric_name, quality)``.

    Layout of ``header``:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    model_id, packed = header
    quality = 1 + (packed & 0x0F)
    metric_id = packed >> 4
    model_name = inverse_dict(model_ids)[model_id]
    metric_name = inverse_dict(metric_ids)[metric_id]
    return model_name, metric_name, quality
def pad(x, p=2 ** 6):
    """Zero-pad dimensions 2 and 3 of ``x`` up to the next multiple of ``p``.

    The border is split as evenly as possible; when the amount is odd the
    extra row/column goes to the bottom/right.
    """
    h, w = x.shape[2], x.shape[3]
    extra_h = -h % p  # distance from h to the next multiple of p
    extra_w = -w % p
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    return F.pad(x, (left, right, top, bottom), mode="constant", value=0)
def crop(x, size):
    """Center-crop dimensions 2 and 3 of ``x`` to ``size = (h, w)``.

    Implemented as negative padding; when the surplus is odd the extra
    row/column is removed from the bottom/right.
    """
    target_h, target_w = size
    surplus_h = x.shape[2] - target_h
    surplus_w = x.shape[3] - target_w
    top, left = surplus_h // 2, surplus_w // 2
    bottom, right = surplus_h - top, surplus_w - left
    return F.pad(x, (-left, -right, -top, -bottom), mode="constant", value=0)
def compute_psnr(a, b):
    """Return the PSNR in dB between ``a`` and ``b`` for a peak value of 1.

    Computed as ``-10 * log10(mse)``. Identical inputs have zero error, in
    which case ``float("inf")`` is returned instead of letting
    ``math.log10(0)`` raise ``ValueError``.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float("inf")
    return -10 * math.log10(mse)
def _code_tensor(net, x, output):
    """Compress ``x``, append the payload to ``output`` and return the decoded tensor.

    The record appended to the bit-stream is three unsigned ints (two shape
    values and the number of latent strings) followed by every string,
    each prefixed with its length.
    """
    with torch.no_grad():
        out = net.compress(x)
    shape = out["shape"]
    n_strings = len(out["strings"])
    strings = []
    with Path(output).open("ab") as f:
        # write shape and number of encoded latents
        write_uints(f, (shape[0], shape[1], n_strings))
        for s in out["strings"]:
            write_uints(f, (len(s[0]),))
            write_bytes(f, s[0])
            strings.append([s[0]])
    # NOTE(review): the tuple handed to decompress carries the string count as
    # a third element, as before -- confirm the model only uses the first two.
    with torch.no_grad():
        decoded = net.decompress(strings, (shape[0], shape[1], n_strings))
    return decoded["x_hat"]


def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
    """Encode one frame, append it to the bit-stream and log its statistics.

    Every frame is compressed on its own; for frames other than the first the
    clamped difference to ``ref`` is additionally saved as an image.

    Args:
        path: directory prefix (with trailing slash); images go to ``<path>recon/``.
        image: path of the frame to encode.
        model, metric, quality: select the pretrained network from ``models``.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        i: truthy for the first frame, falsy for the following ones.
        ref: reconstruction of the previous frame (only used when ``i`` is falsy).
        total_bpp: cumulative bpp before this frame; the log shows the increment.
        ff: frame index used in the output file names.
        output: bit-stream file, opened in append mode.
        log_path: text log file, opened in append mode.

    Returns:
        ``(psnr, bpp, x_recon, enc_time)`` where ``bpp`` is computed from the
        size of the whole bit-stream so far and ``x_recon`` has been clamped
        to [0, 1] in place by ``torch2img``.
    """
    compressai.set_entropy_coder(coder)
    enc_start = time.time()

    img = load_image(image)
    start = time.time()
    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
    load_time = time.time() - start

    # Saving below fails if the directory is missing, so create it up front.
    Path(path + "recon").mkdir(parents=True, exist_ok=True)

    x = img2torch(img)
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2

    # Only the network input is padded. PSNR and the frame difference use the
    # unpadded frame so the shapes match x_recon/ref even when padding is needed.
    x_recon = crop(_code_tensor(net, pad(x, p), output), (h, w))
    psnr = compute_psnr(x, x_recon)

    if not i:
        diff = x - ref
        # Shift the difference from [-0.5, 0.5] into the displayable [0, 1] range.
        diff1 = torch.clamp(diff, min=-0.5, max=0.5) + 0.5
        torch2img(diff1).save(f"{path}recon/diff_v1_{ff}_q{quality}.png")

    enc_time = time.time() - enc_start
    size = filesize(output)
    bpp = float(size) * 8 / (img.size[0] * img.size[1] * 3)
    with Path(log_path).open("a") as f:
        f.write(
            f" {bpp-total_bpp:.4f} | "
            f" {psnr:.4f} |"
            f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n"
        )
    torch2img(x_recon).save(f"{path}recon/v1_recon{ff}_q{quality}.png")
    return psnr, bpp, x_recon, enc_time
def _decode(inputpath, coder, show, frame, output=None):
    """Decode ``frame`` frame records from the bit-stream at ``inputpath``.

    Expected layout: two header bytes, the frame size as two unsigned ints
    (written once by ``encode``), then one record per frame made of two shape
    values, the number of strings, and the length-prefixed strings.

    Args:
        inputpath: path of the bit-stream file.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        show: when truthy, display every decoded frame with ``show_image``.
        frame: number of frame records to read.
        output: optional path prefix; frame ``i`` is saved as
            ``<output>_frame<i>.png``.
    """
    compressai.set_entropy_coder(coder)
    dec_start = time.time()
    with Path(inputpath).open("rb") as f:
        model, metric, quality = parse_header(read_uchars(f, 2))
        print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
        # `encode` writes the frame size a single time right after the header,
        # so reading it for every frame would desynchronise the stream.
        original_size = read_uints(f, 2)

        # The network is identical for all frames: load it once.
        start = time.time()
        net = models[model](quality=quality, metric=metric, pretrained=True).eval()
        load_time = time.time() - start

        for i in range(frame):
            shape = read_uints(f, 2)
            n_strings = read_uints(f, 1)[0]
            strings = [[read_bytes(f, read_uints(f, 1)[0])] for _ in range(n_strings)]
            with torch.no_grad():
                out = net.decompress(strings, shape)
            x_hat = crop(out["x_hat"], original_size)
            img = torch2img(x_hat)
            # Elapsed time since decoding started (cumulative across frames).
            dec_time = time.time() - dec_start
            print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
            if show:
                show_image(img)
            if output is not None:
                img.save(output + "_frame" + str(i) + ".png")
def show_image(img: Image.Image):
    """Display ``img`` in a matplotlib figure titled "Decoded image", axes hidden."""
    # Local import: matplotlib is loaded only when this function is called.
    import matplotlib.pyplot as plt

    figure, axes = plt.subplots()
    axes.axis("off")
    axes.title.set_text("Decoded image")
    axes.imshow(img)
    figure.tight_layout()
    plt.show()
def encode(argv):
    """Parse the ``encode`` arguments and encode a sequence of frames.

    Frames are read from
    ``examples/<image>/<image>_768x768_<framerate>_8bit_444_frame<N>.png`` for
    ``N`` in ``range(frame)``. The bit-stream starts with the two header bytes
    and the ``(hight, width)`` pair; ``_encode`` then appends one frame at a
    time. A per-frame log and the totals are written under
    ``examples/<image>/`` and the average PSNR and bpp are printed.

    Args:
        argv: command-line arguments following the ``encode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Encode image to bit-stream")
    parser.add_argument("image", type=str)
    parser.add_argument(
        "--model",
        choices=models.keys(),
        default=list(models.keys())[0],
        help="NN model to use (default: %(default)s)",
    )
    parser.add_argument(
        "-m",
        "--metric",
        choices=["mse"],
        default="mse",
        help="metric trained against (default: %(default)s",
    )
    parser.add_argument(
        "-q",
        "--quality",
        choices=list(range(1, 9)),
        type=int,
        default=3,
        help="Quality setting (default: %(default)s)",
    )
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument(
        "-fr",
        "--framerate",
        choices=[60, 50, 24],
        type=int,
        default=50,
        help="Frame rate setting (default: %(default)s)",
    )
    # The help specifier needs its conversion type: a bare "%(default)" makes
    # argparse raise when it formats the help text.
    parser.add_argument(
        "-width",
        "--width",
        type=int,
        default=768,
        help="width setting (default: %(default)s)",
    )
    parser.add_argument(
        "-hight",
        "--hight",
        type=int,
        default=768,
        help="hight setting (default: %(default)s)",
    )
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    path = "examples/" + args.image + "/"
    stem = f"{path}{args.image}_{args.model}_q{args.quality}_v1"
    if not args.output:
        args.output = stem + ".bin"
    log_path = stem + ".txt"

    # Start a fresh bit-stream: header bytes followed by the frame size.
    header = get_header(args.model, args.metric, args.quality)
    with Path(args.output).open("wb") as f:
        write_uchars(f, header)
        write_uints(f, (args.hight, args.width))
    with Path(log_path).open("w") as f:
        f.write(
            f"model : {args.model} | "
            f"quality : {args.quality} | "
            f"frames : {args.frame}\n"
        )
        f.write(
            f"frame | bpp | "
            f" psnr |"
            f" Encoded time (model loading)\n"
            f" {0:3d} | "
        )

    total_bpp = 0.0
    frame_prefix = f"{path}{args.image}_768x768_{args.framerate}_8bit_444"

    # The first frame is flagged with i=True and has no reference.
    total_psnr, total_bpp, ref, total_time = _encode(
        path, frame_prefix + "_frame0.png", args.model, args.metric, args.quality,
        args.coder, True, 0, total_bpp, 0, args.output, log_path,
    )
    for ff in range(1, args.frame):
        with Path(log_path).open("a") as f:
            f.write(f" {ff:3d} | ")
        img = f"{frame_prefix}_frame{ff}.png"
        # `enc_time` rather than `time`, to avoid shadowing the time module.
        psnr, total_bpp, ref, enc_time = _encode(
            path, img, args.model, args.metric, args.quality,
            args.coder, False, ref, total_bpp, ff, args.output, log_path,
        )
        total_psnr += psnr
        total_time += enc_time

    # total_bpp is cumulative over the whole stream; dividing by the frame
    # count gives the per-frame average.
    total_psnr /= args.frame
    total_bpp /= args.frame
    with Path(log_path).open("a") as f:
        f.write(
            f"\n Total Encoded time: {total_time:.2f}s\n"
            f"\n Total PSNR: {total_psnr:.6f}\n"
            f" Total BPP: {total_bpp:.6f}\n"
        )
    print(total_psnr)
    print(total_bpp)
def decode(argv):
    """Parse the ``decode`` arguments and decode a bit-stream into frames.

    The bit-stream is read from
    ``examples/<input>/<input>_768x768_<frame // 2>_8bit_444.bin``. When
    ``-o`` is given, frames are saved with the prefix
    ``examples/recon/<output>/<output>_768x768_50_8bit_444``; otherwise
    nothing is saved.

    Args:
        argv: command-line arguments following the ``decode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
    parser.add_argument("input", type=str)
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        choices=list(range(1, 600)),
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument("--show", action="store_true")
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    args.input = (
        "examples/" + args.input + "/" + args.input
        + "_768x768_" + str(args.frame // 2) + "_8bit_444.bin"
    )
    # Without -o the option is None; concatenating it would raise TypeError,
    # and _decode already treats None as "do not save".
    if args.output is not None:
        args.output = (
            "examples/recon/" + args.output + "/" + args.output
            + "_768x768_" + str(50) + "_8bit_444"
        )
    _decode(args.input, args.coder, args.show, args.frame, args.output)
def parse_args(argv):
    """Parse the sub-command (``encode`` or ``decode``) from ``argv``."""
    command_parser = argparse.ArgumentParser(description="")
    command_parser.add_argument("command", choices=["encode", "decode"])
    return command_parser.parse_args(argv)
def main(argv):
    """Dispatch to ``encode`` or ``decode`` according to ``argv[1]``.

    The remaining arguments (``argv[2:]``) are forwarded to the chosen command.
    """
    command = parse_args(argv[1:2]).command
    remaining = argv[2:]
    torch.set_num_threads(1)  # just to be sure
    # parse_args restricts the command to exactly these two choices.
    handlers = {"encode": encode, "decode": decode}
    handlers[command](remaining)
# Run the command-line interface when the file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import struct
import sys
import time
import math
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import compressai
from compressai.zoo import models
# Numeric id for each model name, in the iteration order of the model zoo.
# get_header() stores this id in the bit-stream header.
model_ids = {name: idx for idx, name in enumerate(models)}
# Numeric id for each supported training metric (also stored in the header).
metric_ids = {"mse": 0}
def inverse_dict(d):
    """Return a new dict mapping every value of ``d`` back to its key.

    Raises:
        ValueError: if two keys share the same value, because the inverse
            mapping would silently drop one of them.
    """
    inverted = {value: key for key, value in d.items()}
    # Uniqueness has to be checked on the values; comparing the keys with
    # themselves (as the previous assert did) can never fail.
    if len(inverted) != len(d):
        raise ValueError("Cannot invert a dict with duplicate values.")
    return inverted
def filesize(filepath: str) -> int:
    """Return the size in bytes of the regular file at ``filepath``.

    Raises:
        ValueError: if ``filepath`` is not an existing regular file.
    """
    target = Path(filepath)
    if target.is_file():
        return target.stat().st_size
    raise ValueError(f'Invalid file "{filepath}".')
def load_image(filepath: str) -> Image.Image:
    """Open the image at ``filepath`` and return it converted to RGB mode."""
    image = Image.open(filepath)
    return image.convert("RGB")
def img2torch(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a tensor and prepend a batch dimension of size 1."""
    to_tensor = ToTensor()
    tensor = to_tensor(img)
    return tensor.unsqueeze(0)
def torch2img(x: torch.Tensor) -> Image.Image:
    """Convert a tensor to a PIL image.

    ``x`` is clamped to [0, 1] *in place*, so the caller's tensor is modified
    as a side effect.
    """
    clamped = x.clamp_(0, 1)
    to_pil = ToPILImage()
    return to_pil(clamped.squeeze())
def write_uints(fd, values, fmt=">{:d}I"):
    """Pack ``values`` as unsigned ints (big-endian by default) and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def write_uchars(fd, values, fmt=">{:d}B"):
    """Pack ``values`` as unsigned bytes and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def read_uints(fd, n, fmt=">{:d}I"):
    """Read ``n`` unsigned ints from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("I")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def read_uchars(fd, n, fmt=">{:d}B"):
    """Read ``n`` unsigned bytes from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("B")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def write_bytes(fd, values, fmt=">{:d}s"):
    """Write the byte string ``values`` to ``fd``; an empty string writes nothing."""
    count = len(values)
    if count:
        fd.write(struct.pack(fmt.format(count), values))
def read_bytes(fd, n, fmt=">{:d}s"):
    """Read ``n`` bytes from ``fd`` and return them as one ``bytes`` object."""
    raw = fd.read(n * struct.calcsize("s"))
    (payload,) = struct.unpack(fmt.format(n), raw)
    return payload
def get_header(model_name, metric, quality):
    """Build the two header values describing the codec configuration.

    Layout:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    metric_id = metric_ids[metric]
    quality_bits = (quality - 1) & 0x0F
    packed = (metric_id << 4) | quality_bits
    return model_ids[model_name], packed
def parse_header(header):
    """Decode the two header values into ``(model_name, metric_name, quality)``.

    Layout of ``header``:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    model_id, packed = header
    quality = 1 + (packed & 0x0F)
    metric_id = packed >> 4
    model_name = inverse_dict(model_ids)[model_id]
    metric_name = inverse_dict(metric_ids)[metric_id]
    return model_name, metric_name, quality
def pad(x, p=2 ** 6):
    """Zero-pad dimensions 2 and 3 of ``x`` up to the next multiple of ``p``.

    The border is split as evenly as possible; when the amount is odd the
    extra row/column goes to the bottom/right.
    """
    h, w = x.shape[2], x.shape[3]
    extra_h = -h % p  # distance from h to the next multiple of p
    extra_w = -w % p
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    return F.pad(x, (left, right, top, bottom), mode="constant", value=0)
def crop(x, size):
    """Center-crop dimensions 2 and 3 of ``x`` to ``size = (h, w)``.

    Implemented as negative padding; when the surplus is odd the extra
    row/column is removed from the bottom/right.
    """
    target_h, target_w = size
    surplus_h = x.shape[2] - target_h
    surplus_w = x.shape[3] - target_w
    top, left = surplus_h // 2, surplus_w // 2
    bottom, right = surplus_h - top, surplus_w - left
    return F.pad(x, (-left, -right, -top, -bottom), mode="constant", value=0)
def compute_psnr(a, b):
    """Return the PSNR in dB between ``a`` and ``b`` for a peak value of 1.

    Computed as ``-10 * log10(mse)``. Identical inputs have zero error, in
    which case ``float("inf")`` is returned instead of letting
    ``math.log10(0)`` raise ``ValueError``.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float("inf")
    return -10 * math.log10(mse)
def _code_tensor(net, x, output):
    """Compress ``x``, append the payload to ``output`` and return the decoded tensor.

    The record appended to the bit-stream is three unsigned ints (two shape
    values and the number of latent strings) followed by every string,
    each prefixed with its length.
    """
    with torch.no_grad():
        out = net.compress(x)
    shape = out["shape"]
    n_strings = len(out["strings"])
    strings = []
    with Path(output).open("ab") as f:
        # write shape and number of encoded latents
        write_uints(f, (shape[0], shape[1], n_strings))
        for s in out["strings"]:
            write_uints(f, (len(s[0]),))
            write_bytes(f, s[0])
            strings.append([s[0]])
    # NOTE(review): the tuple handed to decompress carries the string count as
    # a third element, as before -- confirm the model only uses the first two.
    with torch.no_grad():
        decoded = net.decompress(strings, (shape[0], shape[1], n_strings))
    return decoded["x_hat"]


def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
    """Encode one frame, append it to the bit-stream and log its statistics.

    The first frame (``i`` truthy) is compressed directly. Every other frame
    is coded as a residual against ``ref``: the difference is clamped to
    [-0.5, 0.5], shifted by +0.5 into [0, 1], compressed, and the frame is
    rebuilt as ``ref + decoded_residual - 0.5``.

    Args:
        path: directory prefix (with trailing slash); images go to ``<path>recon/``.
        image: path of the frame to encode.
        model, metric, quality: select the pretrained network from ``models``.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        i: truthy for the first frame, falsy for residual-coded frames.
        ref: reconstruction of the previous frame (only used when ``i`` is falsy).
        total_bpp: cumulative bpp before this frame; the log shows the increment.
        ff: frame index used in the output file names.
        output: bit-stream file, opened in append mode.
        log_path: text log file, opened in append mode.

    Returns:
        ``(psnr, bpp, x_recon, enc_time)`` where ``bpp`` is computed from the
        size of the whole bit-stream so far and ``x_recon`` has been clamped
        to [0, 1] in place by ``torch2img``.
    """
    compressai.set_entropy_coder(coder)
    enc_start = time.time()

    img = load_image(image)
    start = time.time()
    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
    load_time = time.time() - start

    # Saving below fails if the directory is missing, so create it up front.
    Path(path + "recon").mkdir(parents=True, exist_ok=True)

    x = img2torch(img)
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2

    # Only the network input is padded. PSNR and the residual use the unpadded
    # frame so the shapes match x_recon/ref even when padding is needed.
    if i:
        x_recon = crop(_code_tensor(net, pad(x, p), output), (h, w))
        psnr = compute_psnr(x, x_recon)
    else:
        diff = x - ref
        # Shift the residual from [-0.5, 0.5] into the [0, 1] range.
        diff1 = torch.clamp(diff, min=-0.5, max=0.5) + 0.5
        x_hat1 = crop(_code_tensor(net, pad(diff1, p), output), (h, w))
        # Undo the +0.5 shift when adding the residual back to the reference.
        x_recon = ref + x_hat1 - 0.5
        psnr = compute_psnr(x, x_recon)
        torch2img(diff1).save(f"{path}recon/diff{ff}_q{quality}.png")

    enc_time = time.time() - enc_start
    size = filesize(output)
    bpp = float(size) * 8 / (img.size[0] * img.size[1] * 3)
    with Path(log_path).open("a") as f:
        f.write(
            f" {bpp-total_bpp:.4f} | "
            f" {psnr:.4f} |"
            f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n"
        )
    torch2img(x_recon).save(f"{path}recon/recon{ff}_q{quality}.png")
    return psnr, bpp, x_recon, enc_time
def _decode(inputpath, coder, show, frame, output=None):
    """Decode ``frame`` frame records from the bit-stream at ``inputpath``.

    Expected layout: two header bytes, the frame size as two unsigned ints
    (written once by ``encode``), then one record per frame made of two shape
    values, the number of strings, and the length-prefixed strings.

    NOTE(review): the accompanying ``_encode`` stores every frame after the
    first as a residual (offset by 0.5) against the previous reconstruction.
    This decoder outputs each record as a standalone image and does not add
    the reference back -- confirm whether residual reconstruction is missing.

    Args:
        inputpath: path of the bit-stream file.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        show: when truthy, display every decoded frame with ``show_image``.
        frame: number of frame records to read.
        output: optional path prefix; frame ``i`` is saved as
            ``<output>_frame<i>.png``.
    """
    compressai.set_entropy_coder(coder)
    dec_start = time.time()
    with Path(inputpath).open("rb") as f:
        model, metric, quality = parse_header(read_uchars(f, 2))
        print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
        # `encode` writes the frame size a single time right after the header,
        # so reading it for every frame would desynchronise the stream.
        original_size = read_uints(f, 2)

        # The network is identical for all frames: load it once.
        start = time.time()
        net = models[model](quality=quality, metric=metric, pretrained=True).eval()
        load_time = time.time() - start

        for i in range(frame):
            shape = read_uints(f, 2)
            n_strings = read_uints(f, 1)[0]
            strings = [[read_bytes(f, read_uints(f, 1)[0])] for _ in range(n_strings)]
            with torch.no_grad():
                out = net.decompress(strings, shape)
            x_hat = crop(out["x_hat"], original_size)
            img = torch2img(x_hat)
            # Elapsed time since decoding started (cumulative across frames).
            dec_time = time.time() - dec_start
            print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
            if show:
                show_image(img)
            if output is not None:
                img.save(output + "_frame" + str(i) + ".png")
def show_image(img: Image.Image):
    """Display ``img`` in a matplotlib figure titled "Decoded image", axes hidden."""
    # Local import: matplotlib is loaded only when this function is called.
    import matplotlib.pyplot as plt

    figure, axes = plt.subplots()
    axes.axis("off")
    axes.title.set_text("Decoded image")
    axes.imshow(img)
    figure.tight_layout()
    plt.show()
def encode(argv):
    """Parse the ``encode`` arguments and encode a sequence of frames.

    Frames are read from
    ``examples/<image>/<image>_768x768_<framerate>_8bit_444_frame<N>.png`` for
    ``N`` in ``range(frame)``. The bit-stream starts with the two header bytes
    and the ``(height, width)`` pair; ``_encode`` then appends one frame at a
    time. A per-frame log and the totals are written under
    ``examples/<image>/`` and the average PSNR and bpp are printed.

    Args:
        argv: command-line arguments following the ``encode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Encode image to bit-stream")
    parser.add_argument("image", type=str)
    parser.add_argument(
        "--model",
        choices=models.keys(),
        default=list(models.keys())[0],
        help="NN model to use (default: %(default)s)",
    )
    parser.add_argument(
        "-m",
        "--metric",
        choices=["mse"],
        default="mse",
        help="metric trained against (default: %(default)s",
    )
    parser.add_argument(
        "-q",
        "--quality",
        choices=list(range(1, 9)),
        type=int,
        default=3,
        help="Quality setting (default: %(default)s)",
    )
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument(
        "-fr",
        "--framerate",
        choices=[60, 50, 24],
        type=int,
        default=50,
        help="Frame rate setting (default: %(default)s)",
    )
    # The help specifier needs its conversion type: a bare "%(default)" makes
    # argparse raise when it formats the help text.
    parser.add_argument(
        "-width",
        "--width",
        type=int,
        default=768,
        help="width setting (default: %(default)s)",
    )
    parser.add_argument(
        "-height",
        "--height",
        type=int,
        default=768,
        help="hight setting (default: %(default)s)",
    )
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    path = "examples/" + args.image + "/"
    stem = f"{path}{args.image}_{args.model}_q{args.quality}_v2"
    if not args.output:
        args.output = stem + ".bin"
    log_path = stem + ".txt"

    # Start a fresh bit-stream: header bytes followed by the frame size.
    header = get_header(args.model, args.metric, args.quality)
    with Path(args.output).open("wb") as f:
        write_uchars(f, header)
        write_uints(f, (args.height, args.width))
    with Path(log_path).open("w") as f:
        f.write(
            f"model : {args.model} | "
            f"quality : {args.quality} | "
            f"frames : {args.frame}\n"
        )
        f.write(
            f"frame | bpp | "
            f" psnr |"
            f" Encoded time (model loading)\n"
            f" {0:3d} | "
        )

    total_bpp = 0.0
    frame_prefix = f"{path}{args.image}_768x768_{args.framerate}_8bit_444"

    # The first frame is flagged with i=True and has no reference.
    total_psnr, total_bpp, ref, total_time = _encode(
        path, frame_prefix + "_frame0.png", args.model, args.metric, args.quality,
        args.coder, True, 0, total_bpp, 0, args.output, log_path,
    )
    for ff in range(1, args.frame):
        with Path(log_path).open("a") as f:
            f.write(f" {ff:3d} | ")
        img = f"{frame_prefix}_frame{ff}.png"
        # `enc_time` rather than `time`, to avoid shadowing the time module.
        psnr, total_bpp, ref, enc_time = _encode(
            path, img, args.model, args.metric, args.quality,
            args.coder, False, ref, total_bpp, ff, args.output, log_path,
        )
        total_psnr += psnr
        total_time += enc_time

    # total_bpp is cumulative over the whole stream; dividing by the frame
    # count gives the per-frame average.
    total_psnr /= args.frame
    total_bpp /= args.frame
    with Path(log_path).open("a") as f:
        f.write(
            f"\n Total Encoded time: {total_time:.2f}s\n"
            f"\n Total PSNR: {total_psnr:.6f}\n"
            f" Total BPP: {total_bpp:.6f}\n"
        )
    print(total_psnr)
    print(total_bpp)
def decode(argv):
    """Parse the ``decode`` arguments and decode a bit-stream into frames.

    The bit-stream is read from
    ``examples/<input>/<input>_768x768_<frame // 2>_8bit_444.bin``. When
    ``-o`` is given, frames are saved with the prefix
    ``examples/recon/<output>/<output>_768x768_50_8bit_444``; otherwise
    nothing is saved.

    Args:
        argv: command-line arguments following the ``decode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
    parser.add_argument("input", type=str)
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        choices=list(range(1, 600)),
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument("--show", action="store_true")
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    args.input = (
        "examples/" + args.input + "/" + args.input
        + "_768x768_" + str(args.frame // 2) + "_8bit_444.bin"
    )
    # Without -o the option is None; concatenating it would raise TypeError,
    # and _decode already treats None as "do not save".
    if args.output is not None:
        args.output = (
            "examples/recon/" + args.output + "/" + args.output
            + "_768x768_" + str(50) + "_8bit_444"
        )
    _decode(args.input, args.coder, args.show, args.frame, args.output)
def parse_args(argv):
    """Parse the sub-command (``encode`` or ``decode``) from ``argv``."""
    command_parser = argparse.ArgumentParser(description="")
    command_parser.add_argument("command", choices=["encode", "decode"])
    return command_parser.parse_args(argv)
def main(argv):
    """Dispatch to ``encode`` or ``decode`` according to ``argv[1]``.

    The remaining arguments (``argv[2:]``) are forwarded to the chosen command.
    """
    command = parse_args(argv[1:2]).command
    remaining = argv[2:]
    torch.set_num_threads(1)  # just to be sure
    # parse_args restricts the command to exactly these two choices.
    handlers = {"encode": encode, "decode": decode}
    handlers[command](remaining)
# Run the command-line interface when the file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import struct
import sys
import time
import math
from pathlib import Path
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import compressai
from compressai.zoo import models
# Numeric id for each model name, in the iteration order of the model zoo.
# get_header() stores this id in the bit-stream header.
model_ids = {name: idx for idx, name in enumerate(models)}
# Numeric id for each supported training metric (also stored in the header).
metric_ids = {"mse": 0}
def inverse_dict(d):
    """Return a new dict mapping every value of ``d`` back to its key.

    Raises:
        ValueError: if two keys share the same value, because the inverse
            mapping would silently drop one of them.
    """
    inverted = {value: key for key, value in d.items()}
    # Uniqueness has to be checked on the values; comparing the keys with
    # themselves (as the previous assert did) can never fail.
    if len(inverted) != len(d):
        raise ValueError("Cannot invert a dict with duplicate values.")
    return inverted
def filesize(filepath: str) -> int:
    """Return the size in bytes of the regular file at ``filepath``.

    Raises:
        ValueError: if ``filepath`` is not an existing regular file.
    """
    target = Path(filepath)
    if target.is_file():
        return target.stat().st_size
    raise ValueError(f'Invalid file "{filepath}".')
def load_image(filepath: str) -> Image.Image:
    """Open the image at ``filepath`` and return it converted to RGB mode."""
    image = Image.open(filepath)
    return image.convert("RGB")
def img2torch(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a tensor and prepend a batch dimension of size 1."""
    to_tensor = ToTensor()
    tensor = to_tensor(img)
    return tensor.unsqueeze(0)
def torch2img(x: torch.Tensor) -> Image.Image:
    """Convert a tensor to a PIL image.

    ``x`` is clamped to [0, 1] *in place*, so the caller's tensor is modified
    as a side effect.
    """
    clamped = x.clamp_(0, 1)
    to_pil = ToPILImage()
    return to_pil(clamped.squeeze())
def write_uints(fd, values, fmt=">{:d}I"):
    """Pack ``values`` as unsigned ints (big-endian by default) and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def write_uchars(fd, values, fmt=">{:d}B"):
    """Pack ``values`` as unsigned bytes and write them to ``fd``."""
    layout = fmt.format(len(values))
    packed = struct.pack(layout, *values)
    fd.write(packed)
def read_uints(fd, n, fmt=">{:d}I"):
    """Read ``n`` unsigned ints from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("I")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def read_uchars(fd, n, fmt=">{:d}B"):
    """Read ``n`` unsigned bytes from ``fd`` and return them as a tuple."""
    item_size = struct.calcsize("B")
    raw = fd.read(n * item_size)
    return struct.unpack(fmt.format(n), raw)
def write_bytes(fd, values, fmt=">{:d}s"):
    """Write the byte string ``values`` to ``fd``; an empty string writes nothing."""
    count = len(values)
    if count:
        fd.write(struct.pack(fmt.format(count), values))
def read_bytes(fd, n, fmt=">{:d}s"):
    """Read ``n`` bytes from ``fd`` and return them as one ``bytes`` object."""
    raw = fd.read(n * struct.calcsize("s"))
    (payload,) = struct.unpack(fmt.format(n), raw)
    return payload
def get_header(model_name, metric, quality):
    """Build the two header values describing the codec configuration.

    Layout:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    metric_id = metric_ids[metric]
    quality_bits = (quality - 1) & 0x0F
    packed = (metric_id << 4) | quality_bits
    return model_ids[model_name], packed
def parse_header(header):
    """Decode the two header values into ``(model_name, metric_name, quality)``.

    Layout of ``header``:
    - 1 byte for the model id
    - 4 bits for the metric id (high nibble of the second value)
    - 4 bits for ``quality - 1`` (low nibble of the second value)
    """
    model_id, packed = header
    quality = 1 + (packed & 0x0F)
    metric_id = packed >> 4
    model_name = inverse_dict(model_ids)[model_id]
    metric_name = inverse_dict(metric_ids)[metric_id]
    return model_name, metric_name, quality
def pad(x, p=2 ** 6):
    """Zero-pad dimensions 2 and 3 of ``x`` up to the next multiple of ``p``.

    The border is split as evenly as possible; when the amount is odd the
    extra row/column goes to the bottom/right.
    """
    h, w = x.shape[2], x.shape[3]
    extra_h = -h % p  # distance from h to the next multiple of p
    extra_w = -w % p
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    return F.pad(x, (left, right, top, bottom), mode="constant", value=0)
def crop(x, size):
    """Center-crop dimensions 2 and 3 of ``x`` to ``size = (h, w)``.

    Implemented as negative padding; when the surplus is odd the extra
    row/column is removed from the bottom/right.
    """
    target_h, target_w = size
    surplus_h = x.shape[2] - target_h
    surplus_w = x.shape[3] - target_w
    top, left = surplus_h // 2, surplus_w // 2
    bottom, right = surplus_h - top, surplus_w - left
    return F.pad(x, (-left, -right, -top, -bottom), mode="constant", value=0)
def compute_psnr(a, b):
    """Return the PSNR in dB between ``a`` and ``b`` for a peak value of 1.

    Computed as ``-10 * log10(mse)``. Identical inputs have zero error, in
    which case ``float("inf")`` is returned instead of letting
    ``math.log10(0)`` raise ``ValueError``.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        return float("inf")
    return -10 * math.log10(mse)
def _code_tensor(net, x, output):
    """Compress ``x``, append the payload to ``output`` and return the decoded tensor.

    The record appended to the bit-stream is three unsigned ints (two shape
    values and the number of latent strings) followed by every string,
    each prefixed with its length.
    """
    with torch.no_grad():
        out = net.compress(x)
    shape = out["shape"]
    n_strings = len(out["strings"])
    strings = []
    with Path(output).open("ab") as f:
        # write shape and number of encoded latents
        write_uints(f, (shape[0], shape[1], n_strings))
        for s in out["strings"]:
            write_uints(f, (len(s[0]),))
            write_bytes(f, s[0])
            strings.append([s[0]])
    # NOTE(review): the tuple handed to decompress carries the string count as
    # a third element, as before -- confirm the model only uses the first two.
    with torch.no_grad():
        decoded = net.decompress(strings, (shape[0], shape[1], n_strings))
    return decoded["x_hat"]


def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
    """Encode one frame, append it to the bit-stream and log its statistics.

    The first frame (``i`` truthy) is compressed directly. Every other frame
    is coded as two residual records against ``ref``: the positive part of the
    difference and the magnitude of its negative part, each in [0, 1]. The
    frame is rebuilt as ``ref + decoded_positive - decoded_negative``.

    Args:
        path: directory prefix (with trailing slash); images go to ``<path>recon/``.
        image: path of the frame to encode.
        model, metric, quality: select the pretrained network from ``models``.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        i: truthy for the first frame, falsy for residual-coded frames.
        ref: reconstruction of the previous frame (only used when ``i`` is falsy).
        total_bpp: cumulative bpp before this frame; the log shows the increment.
        ff: frame index used in the output file names.
        output: bit-stream file, opened in append mode.
        log_path: text log file, opened in append mode.

    Returns:
        ``(psnr, bpp, x_recon, enc_time)`` where ``bpp`` is computed from the
        size of the whole bit-stream so far and ``x_recon`` has been clamped
        to [0, 1] in place by ``torch2img``.
    """
    compressai.set_entropy_coder(coder)
    enc_start = time.time()

    img = load_image(image)
    start = time.time()
    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
    load_time = time.time() - start

    # Saving below fails if the directory is missing, so create it up front.
    Path(path + "recon").mkdir(parents=True, exist_ok=True)

    x = img2torch(img)
    h, w = x.size(2), x.size(3)
    p = 64  # maximum 6 strides of 2

    # Only the network input is padded. PSNR and the residual use the unpadded
    # frame so the shapes match x_recon/ref even when padding is needed.
    if i:
        x_recon = crop(_code_tensor(net, pad(x, p), output), (h, w))
        psnr = compute_psnr(x, x_recon)
    else:
        diff = x - ref
        # Split the signed residual into two non-negative planes.
        diff1 = pad(torch.clamp(diff, min=0.0, max=1.0), p)
        diff2 = pad(-torch.clamp(diff, min=-1.0, max=0.0), p)
        x_hat1 = crop(_code_tensor(net, diff1, output), (h, w))
        x_hat2 = crop(_code_tensor(net, diff2, output), (h, w))
        x_recon = ref + x_hat1 - x_hat2
        psnr = compute_psnr(x, x_recon)
        torch2img(diff1).save(f"{path}recon/v3_diff_1_{ff}_q{quality}.png")
        torch2img(diff2).save(f"{path}recon/v3_diff_2_{ff}_q{quality}.png")

    enc_time = time.time() - enc_start
    size = filesize(output)
    bpp = float(size) * 8 / (img.size[0] * img.size[1] * 3)
    with Path(log_path).open("a") as f:
        f.write(
            f" {bpp-total_bpp:.4f} | "
            f" {psnr:.4f} |"
            f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n"
        )
    torch2img(x_recon).save(f"{path}recon/v3_recon{ff}_q{quality}.png")
    return psnr, bpp, x_recon, enc_time
def _decode(inputpath, coder, show, frame, output=None):
    """Decode ``frame`` records from the bit-stream at ``inputpath``.

    Expected layout: two header bytes, the frame size as two unsigned ints
    (written once by ``encode``), then a sequence of records, each made of two
    shape values, the number of strings, and the length-prefixed strings.

    NOTE(review): the accompanying ``_encode`` writes two residual records
    for every frame after the first and rebuilds the frame from the previous
    reconstruction. This decoder reads ``frame`` records and outputs each one
    as a standalone image -- confirm whether residual reconstruction is
    missing.

    Args:
        inputpath: path of the bit-stream file.
        coder: entropy coder name passed to ``compressai.set_entropy_coder``.
        show: when truthy, display every decoded record with ``show_image``.
        frame: number of records to read.
        output: optional path prefix; record ``i`` is saved as
            ``<output>_frame<i>.png``.
    """
    compressai.set_entropy_coder(coder)
    dec_start = time.time()
    with Path(inputpath).open("rb") as f:
        model, metric, quality = parse_header(read_uchars(f, 2))
        print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
        # `encode` writes the frame size a single time right after the header,
        # so reading it for every record would desynchronise the stream.
        original_size = read_uints(f, 2)

        # The network is identical for all records: load it once.
        start = time.time()
        net = models[model](quality=quality, metric=metric, pretrained=True).eval()
        load_time = time.time() - start

        for i in range(frame):
            shape = read_uints(f, 2)
            n_strings = read_uints(f, 1)[0]
            strings = [[read_bytes(f, read_uints(f, 1)[0])] for _ in range(n_strings)]
            with torch.no_grad():
                out = net.decompress(strings, shape)
            x_hat = crop(out["x_hat"], original_size)
            img = torch2img(x_hat)
            # Elapsed time since decoding started (cumulative across records).
            dec_time = time.time() - dec_start
            print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
            if show:
                show_image(img)
            if output is not None:
                img.save(output + "_frame" + str(i) + ".png")
def show_image(img: Image.Image):
    """Display ``img`` in a matplotlib figure titled "Decoded image", axes hidden."""
    # Local import: matplotlib is loaded only when this function is called.
    import matplotlib.pyplot as plt

    figure, axes = plt.subplots()
    axes.axis("off")
    axes.title.set_text("Decoded image")
    axes.imshow(img)
    figure.tight_layout()
    plt.show()
def encode(argv):
    """Parse the ``encode`` arguments and encode a sequence of frames.

    Frames are read from
    ``examples/<image>/<image>_768x768_<framerate>_8bit_444_frame<N>.png`` for
    ``N`` in ``range(frame)``. The bit-stream starts with the two header bytes
    and the ``(hight, width)`` pair; ``_encode`` then appends one frame at a
    time. A per-frame log and the totals are written under
    ``examples/<image>/`` and the average PSNR and bpp are printed.

    Args:
        argv: command-line arguments following the ``encode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Encode image to bit-stream")
    parser.add_argument("image", type=str)
    parser.add_argument(
        "--model",
        choices=models.keys(),
        default=list(models.keys())[0],
        help="NN model to use (default: %(default)s)",
    )
    parser.add_argument(
        "-m",
        "--metric",
        choices=["mse"],
        default="mse",
        help="metric trained against (default: %(default)s",
    )
    parser.add_argument(
        "-q",
        "--quality",
        choices=list(range(1, 9)),
        type=int,
        default=3,
        help="Quality setting (default: %(default)s)",
    )
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument(
        "-fr",
        "--framerate",
        choices=[60, 50, 24],
        type=int,
        default=50,
        help="Frame rate setting (default: %(default)s)",
    )
    # The help specifier needs its conversion type: a bare "%(default)" makes
    # argparse raise when it formats the help text.
    parser.add_argument(
        "-width",
        "--width",
        type=int,
        default=768,
        help="width setting (default: %(default)s)",
    )
    parser.add_argument(
        "-hight",
        "--hight",
        type=int,
        default=768,
        help="hight setting (default: %(default)s)",
    )
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    path = "examples/" + args.image + "/"
    stem = f"{path}{args.image}_{args.model}_q{args.quality}_v3"
    if not args.output:
        args.output = stem + ".bin"
    log_path = stem + ".txt"

    # Start a fresh bit-stream: header bytes followed by the frame size.
    header = get_header(args.model, args.metric, args.quality)
    with Path(args.output).open("wb") as f:
        write_uchars(f, header)
        write_uints(f, (args.hight, args.width))
    with Path(log_path).open("w") as f:
        f.write(
            f"model : {args.model} | "
            f"quality : {args.quality} | "
            f"frames : {args.frame}\n"
        )
        f.write(
            f"frame | bpp | "
            f" psnr |"
            f" Encoded time (model loading)\n"
            f" {0:3d} | "
        )

    total_bpp = 0.0
    frame_prefix = f"{path}{args.image}_768x768_{args.framerate}_8bit_444"

    # The first frame is flagged with i=True and has no reference.
    total_psnr, total_bpp, ref, total_time = _encode(
        path, frame_prefix + "_frame0.png", args.model, args.metric, args.quality,
        args.coder, True, 0, total_bpp, 0, args.output, log_path,
    )
    for ff in range(1, args.frame):
        with Path(log_path).open("a") as f:
            f.write(f" {ff:3d} | ")
        img = f"{frame_prefix}_frame{ff}.png"
        # `enc_time` rather than `time`, to avoid shadowing the time module.
        psnr, total_bpp, ref, enc_time = _encode(
            path, img, args.model, args.metric, args.quality,
            args.coder, False, ref, total_bpp, ff, args.output, log_path,
        )
        total_psnr += psnr
        total_time += enc_time

    # total_bpp is cumulative over the whole stream; dividing by the frame
    # count gives the per-frame average.
    total_psnr /= args.frame
    total_bpp /= args.frame
    with Path(log_path).open("a") as f:
        f.write(
            f"\n Total Encoded time: {total_time:.2f}s\n"
            f"\n Total PSNR: {total_psnr:.6f}\n"
            f" Total BPP: {total_bpp:.6f}\n"
        )
    print(total_psnr)
    print(total_bpp)
def decode(argv):
    """Parse the ``decode`` arguments and decode a bit-stream into frames.

    The bit-stream is read from
    ``examples/<input>/<input>_768x768_<frame // 2>_8bit_444_v3.bin``. When
    ``-o`` is given, frames are saved with the prefix
    ``examples/recon/<output>/<output>_768x768_50_8bit_444``; otherwise
    nothing is saved.

    Args:
        argv: command-line arguments following the ``decode`` sub-command.
    """
    parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
    parser.add_argument("input", type=str)
    parser.add_argument(
        "-c",
        "--coder",
        choices=compressai.available_entropy_coders(),
        default=compressai.available_entropy_coders()[0],
        help="Entropy coder (default: %(default)s)",
    )
    parser.add_argument(
        "-f",
        "--frame",
        choices=list(range(1, 600)),
        type=int,
        default=100,
        help="Frame setting (default: %(default)s)",
    )
    parser.add_argument("--show", action="store_true")
    parser.add_argument("-o", "--output", help="Output path")
    args = parser.parse_args(argv)

    args.input = (
        "examples/" + args.input + "/" + args.input
        + "_768x768_" + str(args.frame // 2) + "_8bit_444_v3.bin"
    )
    # Without -o the option is None; concatenating it would raise TypeError,
    # and _decode already treats None as "do not save".
    if args.output is not None:
        args.output = (
            "examples/recon/" + args.output + "/" + args.output
            + "_768x768_" + str(50) + "_8bit_444"
        )
    _decode(args.input, args.coder, args.show, args.frame, args.output)
def parse_args(argv):
    """Parse the sub-command (``encode`` or ``decode``) from ``argv``."""
    command_parser = argparse.ArgumentParser(description="")
    command_parser.add_argument("command", choices=["encode", "decode"])
    return command_parser.parse_args(argv)
def main(argv):
    """Dispatch to ``encode`` or ``decode`` according to ``argv[1]``.

    The remaining arguments (``argv[2:]``) are forwarded to the chosen command.
    """
    command = parse_args(argv[1:2]).command
    remaining = argv[2:]
    torch.set_num_threads(1)  # just to be sure
    # parse_args restricts the command to exactly these two choices.
    handlers = {"encode": encode, "decode": decode}
    handlers[command](remaining)
# Run the command-line interface when the file is executed as a script.
if __name__ == "__main__":
    main(sys.argv)