choiseungmi

Upload our codec

1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import struct
17 +import sys
18 +import time
19 +import math
20 +
21 +from pathlib import Path
22 +
23 +import torch
24 +import torch.nn.functional as F
25 +
26 +from PIL import Image
27 +from torchvision.transforms import ToPILImage, ToTensor
28 +
29 +import compressai
30 +
31 +from compressai.zoo import models
32 +
33 +model_ids = {k: i for i, k in enumerate(models.keys())}
34 +
35 +metric_ids = {
36 + "mse": 0,
37 +}
38 +
39 +
40 +def inverse_dict(d):
41 + # We assume dict values are unique...
42 + assert len(d.keys()) == len(set(d.keys()))
43 + return {v: k for k, v in d.items()}
44 +
45 +
46 +def filesize(filepath: str) -> int:
47 + if not Path(filepath).is_file():
48 + raise ValueError(f'Invalid file "{filepath}".')
49 + return Path(filepath).stat().st_size
50 +
51 +
52 +def load_image(filepath: str) -> Image.Image:
53 + return Image.open(filepath).convert("RGB")
54 +
55 +
56 +def img2torch(img: Image.Image) -> torch.Tensor:
57 + return ToTensor()(img).unsqueeze(0)
58 +
59 +
60 +def torch2img(x: torch.Tensor) -> Image.Image:
61 + return ToPILImage()(x.clamp_(0, 1).squeeze())
62 +
63 +
64 +def write_uints(fd, values, fmt=">{:d}I"):
65 + fd.write(struct.pack(fmt.format(len(values)), *values))
66 +
67 +
68 +def write_uchars(fd, values, fmt=">{:d}B"):
69 + fd.write(struct.pack(fmt.format(len(values)), *values))
70 +
71 +
72 +def read_uints(fd, n, fmt=">{:d}I"):
73 + sz = struct.calcsize("I")
74 + return struct.unpack(fmt.format(n), fd.read(n * sz))
75 +
76 +
77 +def read_uchars(fd, n, fmt=">{:d}B"):
78 + sz = struct.calcsize("B")
79 + return struct.unpack(fmt.format(n), fd.read(n * sz))
80 +
81 +
82 +def write_bytes(fd, values, fmt=">{:d}s"):
83 + if len(values) == 0:
84 + return
85 + fd.write(struct.pack(fmt.format(len(values)), values))
86 +
87 +
88 +def read_bytes(fd, n, fmt=">{:d}s"):
89 + sz = struct.calcsize("s")
90 + return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
91 +
92 +
93 +def get_header(model_name, metric, quality):
94 + """Format header information:
95 + - 1 byte for model id
96 + - 4 bits for metric
97 + - 4 bits for quality param
98 + """
99 + metric = metric_ids[metric]
100 + code = (metric << 4) | (quality - 1 & 0x0F)
101 + return model_ids[model_name], code
102 +
103 +
104 +def parse_header(header):
105 + """Read header information from 2 bytes:
106 + - 1 byte for model id
107 + - 4 bits for metric
108 + - 4 bits for quality param
109 + """
110 + model_id, code = header
111 + quality = (code & 0x0F) + 1
112 + metric = code >> 4
113 + return (
114 + inverse_dict(model_ids)[model_id],
115 + inverse_dict(metric_ids)[metric],
116 + quality,
117 + )
118 +
119 +
120 +def pad(x, p=2 ** 6):
121 + h, w = x.size(2), x.size(3)
122 + H = (h + p - 1) // p * p
123 + W = (w + p - 1) // p * p
124 + padding_left = (W - w) // 2
125 + padding_right = W - w - padding_left
126 + padding_top = (H - h) // 2
127 + padding_bottom = H - h - padding_top
128 + return F.pad(
129 + x,
130 + (padding_left, padding_right, padding_top, padding_bottom),
131 + mode="constant",
132 + value=0,
133 + )
134 +
135 +
136 +def crop(x, size):
137 + H, W = x.size(2), x.size(3)
138 + h, w = size
139 + padding_left = (W - w) // 2
140 + padding_right = W - w - padding_left
141 + padding_top = (H - h) // 2
142 + padding_bottom = H - h - padding_top
143 + return F.pad(
144 + x,
145 + (-padding_left, -padding_right, -padding_top, -padding_bottom),
146 + mode="constant",
147 + value=0,
148 + )
149 +
150 +def compute_psnr(a, b):
151 + mse = torch.mean((a - b)**2).item()
152 + return -10 * math.log10(mse)
153 +
154 +def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path):
155 + compressai.set_entropy_coder(coder)
156 + enc_start = time.time()
157 +
158 + img = load_image(image)
159 + start = time.time()
160 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
161 + load_time = time.time() - start
162 +
163 + x = img2torch(img)
164 + h, w = x.size(2), x.size(3)
165 + p = 64 # maximum 6 strides of 2
166 + x = pad(x, p)
167 +
168 +# header = get_header(model, metric, quality)
169 +
170 + strings = []
171 +
172 + with torch.no_grad():
173 + out = net.compress(x)
174 + shape = out["shape"]
175 + with Path(output).open("ab") as f:
176 + # write shape and number of encoded latents
177 + write_uints(f, (shape[0], shape[1], len(out["strings"])))
178 +
179 + for s in out["strings"]:
180 + write_uints(f, (len(s[0]),))
181 + write_bytes(f, s[0])
182 + strings.append([s[0]])
183 +
184 + with torch.no_grad():
185 + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
186 + x_recon = crop(recon_out["x_hat"], (h, w))
187 +
188 + psnr=compute_psnr(x, x_recon)
189 +
190 + if i==False:
191 + diff=x-ref
192 + diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
193 + diff_img = torch2img(diff1)
194 + diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png")
195 +
196 + enc_time = time.time() - enc_start
197 + size = filesize(output)
198 + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3)
199 + with Path(log_path).open("a") as f:
200 + f.write( f" {bpp-total_bpp:.4f} | "
201 + f" {psnr:.4f} |"
202 + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
203 + recon_img = torch2img(x_recon)
204 + recon_img.save(path+"recon/v1_recon"+str(ff)+"_q"+str(quality)+".png")
205 +
206 + return psnr, bpp, x_recon, enc_time
207 +
208 +
209 +def _decode(inputpath, coder, show, frame, output=None):
210 + compressai.set_entropy_coder(coder)
211 + dec_start = time.time()
212 +
213 + with Path(inputpath).open("rb") as f:
214 + model, metric, quality = parse_header(read_uchars(f, 2))
215 + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
216 +
217 + for i in range(frame):
218 + original_size = read_uints(f, 2)
219 + shape = read_uints(f, 2)
220 + strings = []
221 + n_strings = read_uints(f, 1)[0]
222 + for _ in range(n_strings):
223 + s = read_bytes(f, read_uints(f, 1)[0])
224 + strings.append([s])
225 +
226 + start = time.time()
227 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
228 + load_time = time.time() - start
229 +
230 + with torch.no_grad():
231 + out = net.decompress(strings, shape)
232 +
233 + x_hat = crop(out["x_hat"], original_size)
234 + img = torch2img(x_hat)
235 + dec_time = time.time() - dec_start
236 + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
237 +
238 + if show:
239 + show_image(img)
240 + if output is not None:
241 + img.save(output+"_frame"+str(i)+".png")
242 +
243 +def show_image(img: Image.Image):
244 + from matplotlib import pyplot as plt
245 +
246 + fig, ax = plt.subplots()
247 + ax.axis("off")
248 + ax.title.set_text("Decoded image")
249 + ax.imshow(img)
250 + fig.tight_layout()
251 + plt.show()
252 +
253 +
254 +def encode(argv):
255 + parser = argparse.ArgumentParser(description="Encode image to bit-stream")
256 + parser.add_argument("image", type=str)
257 + parser.add_argument(
258 + "--model",
259 + choices=models.keys(),
260 + default=list(models.keys())[0],
261 + help="NN model to use (default: %(default)s)",
262 + )
263 + parser.add_argument(
264 + "-m",
265 + "--metric",
266 + choices=["mse"],
267 + default="mse",
268 + help="metric trained against (default: %(default)s",
269 + )
270 + parser.add_argument(
271 + "-q",
272 + "--quality",
273 + choices=list(range(1, 9)),
274 + type=int,
275 + default=3,
276 + help="Quality setting (default: %(default)s)",
277 + )
278 + parser.add_argument(
279 + "-c",
280 + "--coder",
281 + choices=compressai.available_entropy_coders(),
282 + default=compressai.available_entropy_coders()[0],
283 + help="Entropy coder (default: %(default)s)",
284 + )
285 + parser.add_argument(
286 + "-f",
287 + "--frame",
288 + type=int,
289 + default=100,
290 + help="Frame setting (default: %(default)s)",
291 + )
292 + parser.add_argument(
293 + "-fr",
294 + "--framerate",
295 + choices=[60,50,24],
296 + type=int,
297 + default=50,
298 + help="Frame rate setting (default: %(default)s)",
299 + )
300 + parser.add_argument(
301 + "-width",
302 + "--width",
303 + type=int,
304 + default=768,
305 + help="width setting (default: %(default))",
306 + )
307 + parser.add_argument(
308 + "-hight",
309 + "--hight",
310 + type=int,
311 + default=768,
312 + help="hight setting (default: %(default))",
313 + )
314 + parser.add_argument("-o", "--output", help="Output path")
315 + args = parser.parse_args(argv)
316 + path="examples/"+args.image+"/"
317 + if not args.output:
318 + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
319 + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v1.bin"
320 + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v1.txt"
321 +
322 + header = get_header(args.model, args.metric, args.quality)
323 + with Path(args.output).open("wb") as f:
324 + write_uchars(f, header)
325 + write_uints(f, (args.hight, args.width))
326 +
327 + with Path(log_path).open("w") as f:
328 + f.write(f"model : {args.model} | "
329 + f"quality : {args.quality} | "
330 + f"frames : {args.frame}\n")
331 + f.write( f"frame | bpp | "
332 + f" psnr |"
333 + f" Encoded time (model loading)\n"
334 + f" {0:3d} | ")
335 +
336 + total_psnr=0.0
337 + total_bpp=0.0
338 + total_time=0.0
339 + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
340 + img=args.image+"_frame"+str(0)+".png"
341 + total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
342 + for ff in range(1, args.frame):
343 + with Path(log_path).open("a") as f:
344 + f.write(f" {ff:3d} | ")
345 + img=args.image+"_frame"+str(ff)+".png"
346 +
347 + psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
348 + total_psnr+=psnr
349 + total_time+=time
350 +
351 + total_psnr/=args.frame
352 + total_bpp/=args.frame
353 +
354 + with Path(log_path).open("a") as f:
355 + f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
356 + f"\n Total PSNR: {total_psnr:.6f}\n"
357 + f" Total BPP: {total_bpp:.6f}\n")
358 + print(total_psnr)
359 + print(total_bpp)
360 +
361 +
362 +def decode(argv):
363 + parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
364 + parser.add_argument("input", type=str)
365 + parser.add_argument(
366 + "-c",
367 + "--coder",
368 + choices=compressai.available_entropy_coders(),
369 + default=compressai.available_entropy_coders()[0],
370 + help="Entropy coder (default: %(default)s)",
371 + )
372 + parser.add_argument(
373 + "-f",
374 + "--frame",
375 + choices=list(range(1, 600)),
376 + type=int,
377 + default=100,
378 + help="Frame setting (default: %(default)s)",
379 + )
380 + parser.add_argument("--show", action="store_true")
381 + parser.add_argument("-o", "--output", help="Output path")
382 + args = parser.parse_args(argv)
383 +
384 + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444.bin"
385 + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444"
386 + _decode(args.input, args.coder, args.show, args.frame, args.output)
387 +
388 +
389 +def parse_args(argv):
390 + parser = argparse.ArgumentParser(description="")
391 + parser.add_argument("command", choices=["encode", "decode"])
392 + args = parser.parse_args(argv)
393 + return args
394 +
395 +
396 +def main(argv):
397 + args = parse_args(argv[1:2])
398 + argv = argv[2:]
399 + torch.set_num_threads(1) # just to be sure
400 + if args.command == "encode":
401 + encode(argv)
402 + elif args.command == "decode":
403 + decode(argv)
404 +
405 +
406 +if __name__ == "__main__":
407 + main(sys.argv)
1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import struct
17 +import sys
18 +import time
19 +import math
20 +
21 +from pathlib import Path
22 +
23 +import torch
24 +import torch.nn.functional as F
25 +
26 +from PIL import Image
27 +from torchvision.transforms import ToPILImage, ToTensor
28 +
29 +import compressai
30 +
31 +from compressai.zoo import models
32 +
33 +model_ids = {k: i for i, k in enumerate(models.keys())}
34 +
35 +metric_ids = {
36 + "mse": 0,
37 +}
38 +
39 +
40 +def inverse_dict(d):
41 + # We assume dict values are unique...
42 + assert len(d.keys()) == len(set(d.keys()))
43 + return {v: k for k, v in d.items()}
44 +
45 +
46 +def filesize(filepath: str) -> int:
47 + if not Path(filepath).is_file():
48 + raise ValueError(f'Invalid file "{filepath}".')
49 + return Path(filepath).stat().st_size
50 +
51 +
52 +def load_image(filepath: str) -> Image.Image:
53 + return Image.open(filepath).convert("RGB")
54 +
55 +
56 +def img2torch(img: Image.Image) -> torch.Tensor:
57 + return ToTensor()(img).unsqueeze(0)
58 +
59 +
60 +def torch2img(x: torch.Tensor) -> Image.Image:
61 + return ToPILImage()(x.clamp_(0, 1).squeeze())
62 +
63 +
64 +def write_uints(fd, values, fmt=">{:d}I"):
65 + fd.write(struct.pack(fmt.format(len(values)), *values))
66 +
67 +
68 +def write_uchars(fd, values, fmt=">{:d}B"):
69 + fd.write(struct.pack(fmt.format(len(values)), *values))
70 +
71 +
72 +def read_uints(fd, n, fmt=">{:d}I"):
73 + sz = struct.calcsize("I")
74 + return struct.unpack(fmt.format(n), fd.read(n * sz))
75 +
76 +
77 +def read_uchars(fd, n, fmt=">{:d}B"):
78 + sz = struct.calcsize("B")
79 + return struct.unpack(fmt.format(n), fd.read(n * sz))
80 +
81 +
82 +def write_bytes(fd, values, fmt=">{:d}s"):
83 + if len(values) == 0:
84 + return
85 + fd.write(struct.pack(fmt.format(len(values)), values))
86 +
87 +
88 +def read_bytes(fd, n, fmt=">{:d}s"):
89 + sz = struct.calcsize("s")
90 + return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
91 +
92 +
93 +def get_header(model_name, metric, quality):
94 + """Format header information:
95 + - 1 byte for model id
96 + - 4 bits for metric
97 + - 4 bits for quality param
98 + """
99 + metric = metric_ids[metric]
100 + code = (metric << 4) | (quality - 1 & 0x0F)
101 + return model_ids[model_name], code
102 +
103 +
104 +def parse_header(header):
105 + """Read header information from 2 bytes:
106 + - 1 byte for model id
107 + - 4 bits for metric
108 + - 4 bits for quality param
109 + """
110 + model_id, code = header
111 + quality = (code & 0x0F) + 1
112 + metric = code >> 4
113 + return (
114 + inverse_dict(model_ids)[model_id],
115 + inverse_dict(metric_ids)[metric],
116 + quality,
117 + )
118 +
119 +
120 +def pad(x, p=2 ** 6):
121 + h, w = x.size(2), x.size(3)
122 + H = (h + p - 1) // p * p
123 + W = (w + p - 1) // p * p
124 + padding_left = (W - w) // 2
125 + padding_right = W - w - padding_left
126 + padding_top = (H - h) // 2
127 + padding_bottom = H - h - padding_top
128 + return F.pad(
129 + x,
130 + (padding_left, padding_right, padding_top, padding_bottom),
131 + mode="constant",
132 + value=0,
133 + )
134 +
135 +
136 +def crop(x, size):
137 + H, W = x.size(2), x.size(3)
138 + h, w = size
139 + padding_left = (W - w) // 2
140 + padding_right = W - w - padding_left
141 + padding_top = (H - h) // 2
142 + padding_bottom = H - h - padding_top
143 + return F.pad(
144 + x,
145 + (-padding_left, -padding_right, -padding_top, -padding_bottom),
146 + mode="constant",
147 + value=0,
148 + )
149 +
150 +def compute_psnr(a, b):
151 + mse = torch.mean((a - b)**2).item()
152 + return -10 * math.log10(mse)
153 +
154 +def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
155 + compressai.set_entropy_coder(coder)
156 + enc_start = time.time()
157 +
158 + img = load_image(image)
159 + start = time.time()
160 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
161 + load_time = time.time() - start
162 +
163 + x = img2torch(img)
164 + h, w = x.size(2), x.size(3)
165 + p = 64 # maximum 6 strides of 2
166 + x = pad(x, p)
167 +
168 +# header = get_header(model, metric, quality)
169 + if i==True:
170 + strings = []
171 +
172 + with torch.no_grad():
173 + out = net.compress(x)
174 + shape = out["shape"]
175 + with Path(output).open("ab") as f:
176 + # write shape and number of encoded latents
177 + write_uints(f, (shape[0], shape[1], len(out["strings"])))
178 +
179 + for s in out["strings"]:
180 + write_uints(f, (len(s[0]),))
181 + write_bytes(f, s[0])
182 + strings.append([s[0]])
183 +
184 + with torch.no_grad():
185 + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
186 + x_recon = crop(recon_out["x_hat"], (h, w))
187 +
188 + psnr=compute_psnr(x, x_recon)
189 + else:
190 + diff=x-ref
191 + #1
192 + diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
193 +
194 + #2
195 + '''
196 + diff1=torch.clamp(diff, min=0.0, max=1.0)
197 + diff2=-torch.clamp(diff, min=-1.0, max=0.0)
198 +
199 + diff1=pad(diff1, p)
200 + diff2=pad(diff2, p)
201 + '''
202 + #1
203 +
204 + with torch.no_grad():
205 + out1 = net.compress(diff1)
206 + shape1 = out1["shape"]
207 + strings = []
208 +
209 + with Path(output).open("ab") as f:
210 + # write shape and number of encoded latents
211 + write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
212 +
213 + for s in out1["strings"]:
214 + write_uints(f, (len(s[0]),))
215 + write_bytes(f, s[0])
216 + strings.append([s[0]])
217 +
218 + with torch.no_grad():
219 + recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
220 + x_hat1 = crop(recon_out["x_hat"], (h, w))
221 +
222 + #2
223 + '''
224 + with torch.no_grad():
225 + out1 = net.compress(diff1)
226 + shape1 = out1["shape"]
227 + strings = []
228 +
229 + with Path(output).open("ab") as f:
230 + # write shape and number of encoded latents
231 + write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
232 +
233 + for s in out1["strings"]:
234 + write_uints(f, (len(s[0]),))
235 + write_bytes(f, s[0])
236 + strings.append([s[0]])
237 +
238 + with torch.no_grad():
239 + recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
240 + x_hat1 = crop(recon_out["x_hat"], (h, w))
241 + with torch.no_grad():
242 + out = net.compress(diff2)
243 + shape = out["shape"]
244 + strings = []
245 +
246 + with Path(output).open("ab") as f:
247 + # write shape and number of encoded latents
248 + write_uints(f, (shape[0], shape[1], len(out["strings"])))
249 +
250 + for s in out["strings"]:
251 + write_uints(f, (len(s[0]),))
252 + write_bytes(f, s[0])
253 + strings.append([s[0]])
254 +
255 + with torch.no_grad():
256 + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
257 + x_hat2 = crop(recon_out["x_hat"], (h, w))
258 + x_recon=ref+x_hat1-x_hat2
259 + '''
260 +
261 + x_recon=ref+x_hat1-0.5
262 + psnr=compute_psnr(x, x_recon)
263 + diff_img = torch2img(diff1)
264 + diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png")
265 +
266 + enc_time = time.time() - enc_start
267 + size = filesize(output)
268 + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3)
269 + with Path(log_path).open("a") as f:
270 + f.write( f" {bpp-total_bpp:.4f} | "
271 + f" {psnr:.4f} |"
272 + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
273 + recon_img = torch2img(x_recon)
274 + recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png")
275 +
276 + return psnr, bpp, x_recon, enc_time
277 +
278 +
279 +def _decode(inputpath, coder, show, frame, output=None):
280 + compressai.set_entropy_coder(coder)
281 + dec_start = time.time()
282 +
283 + with Path(inputpath).open("rb") as f:
284 + model, metric, quality = parse_header(read_uchars(f, 2))
285 + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
286 +
287 + for i in range(frame):
288 + original_size = read_uints(f, 2)
289 + shape = read_uints(f, 2)
290 + strings = []
291 + n_strings = read_uints(f, 1)[0]
292 + for _ in range(n_strings):
293 + s = read_bytes(f, read_uints(f, 1)[0])
294 + strings.append([s])
295 +
296 + start = time.time()
297 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
298 + load_time = time.time() - start
299 +
300 + with torch.no_grad():
301 + out = net.decompress(strings, shape)
302 +
303 + x_hat = crop(out["x_hat"], original_size)
304 + img = torch2img(x_hat)
305 + dec_time = time.time() - dec_start
306 + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
307 +
308 + if show:
309 + show_image(img)
310 + if output is not None:
311 + img.save(output+"_frame"+str(i)+".png")
312 +
313 +def show_image(img: Image.Image):
314 + from matplotlib import pyplot as plt
315 +
316 + fig, ax = plt.subplots()
317 + ax.axis("off")
318 + ax.title.set_text("Decoded image")
319 + ax.imshow(img)
320 + fig.tight_layout()
321 + plt.show()
322 +
323 +
324 +def encode(argv):
325 + parser = argparse.ArgumentParser(description="Encode image to bit-stream")
326 + parser.add_argument("image", type=str)
327 + parser.add_argument(
328 + "--model",
329 + choices=models.keys(),
330 + default=list(models.keys())[0],
331 + help="NN model to use (default: %(default)s)",
332 + )
333 + parser.add_argument(
334 + "-m",
335 + "--metric",
336 + choices=["mse"],
337 + default="mse",
338 + help="metric trained against (default: %(default)s",
339 + )
340 + parser.add_argument(
341 + "-q",
342 + "--quality",
343 + choices=list(range(1, 9)),
344 + type=int,
345 + default=3,
346 + help="Quality setting (default: %(default)s)",
347 + )
348 + parser.add_argument(
349 + "-c",
350 + "--coder",
351 + choices=compressai.available_entropy_coders(),
352 + default=compressai.available_entropy_coders()[0],
353 + help="Entropy coder (default: %(default)s)",
354 + )
355 + parser.add_argument(
356 + "-f",
357 + "--frame",
358 + type=int,
359 + default=100,
360 + help="Frame setting (default: %(default)s)",
361 + )
362 + parser.add_argument(
363 + "-fr",
364 + "--framerate",
365 + choices=[60,50,24],
366 + type=int,
367 + default=50,
368 + help="Frame rate setting (default: %(default)s)",
369 + )
370 + parser.add_argument(
371 + "-width",
372 + "--width",
373 + type=int,
374 + default=768,
375 + help="width setting (default: %(default))",
376 + )
377 + parser.add_argument(
378 + "-height",
379 + "--height",
380 + type=int,
381 + default=768,
382 + help="hight setting (default: %(default))",
383 + )
384 + parser.add_argument("-o", "--output", help="Output path")
385 + args = parser.parse_args(argv)
386 + path="examples/"+args.image+"/"
387 + if not args.output:
388 + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
389 + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin"
390 + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt"
391 +
392 + header = get_header(args.model, args.metric, args.quality)
393 + with Path(args.output).open("wb") as f:
394 + write_uchars(f, header)
395 + write_uints(f, (args.height, args.width))
396 +
397 + with Path(log_path).open("w") as f:
398 + f.write(f"model : {args.model} | "
399 + f"quality : {args.quality} | "
400 + f"frames : {args.frame}\n")
401 + f.write( f"frame | bpp | "
402 + f" psnr |"
403 + f" Encoded time (model loading)\n"
404 + f" {0:3d} | ")
405 +
406 + total_psnr=0.0
407 + total_bpp=0.0
408 + total_time=0.0
409 + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
410 + img=args.image+"_frame"+str(0)+".png"
411 + total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
412 + for ff in range(1, args.frame):
413 + with Path(log_path).open("a") as f:
414 + f.write(f" {ff:3d} | ")
415 + img=args.image+"_frame"+str(ff)+".png"
416 +
417 + psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
418 + total_psnr+=psnr
419 + total_time+=time
420 +
421 + total_psnr/=args.frame
422 + total_bpp/=args.frame
423 +
424 + with Path(log_path).open("a") as f:
425 + f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
426 + f"\n Total PSNR: {total_psnr:.6f}\n"
427 + f" Total BPP: {total_bpp:.6f}\n")
428 + print(total_psnr)
429 + print(total_bpp)
430 +
431 +
432 +def decode(argv):
433 + parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
434 + parser.add_argument("input", type=str)
435 + parser.add_argument(
436 + "-c",
437 + "--coder",
438 + choices=compressai.available_entropy_coders(),
439 + default=compressai.available_entropy_coders()[0],
440 + help="Entropy coder (default: %(default)s)",
441 + )
442 + parser.add_argument(
443 + "-f",
444 + "--frame",
445 + choices=list(range(1, 600)),
446 + type=int,
447 + default=100,
448 + help="Frame setting (default: %(default)s)",
449 + )
450 + parser.add_argument("--show", action="store_true")
451 + parser.add_argument("-o", "--output", help="Output path")
452 + args = parser.parse_args(argv)
453 +
454 + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444.bin"
455 + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444"
456 + _decode(args.input, args.coder, args.show, args.frame, args.output)
457 +
458 +
459 +def parse_args(argv):
460 + parser = argparse.ArgumentParser(description="")
461 + parser.add_argument("command", choices=["encode", "decode"])
462 + args = parser.parse_args(argv)
463 + return args
464 +
465 +
466 +def main(argv):
467 + args = parse_args(argv[1:2])
468 + argv = argv[2:]
469 + torch.set_num_threads(1) # just to be sure
470 + if args.command == "encode":
471 + encode(argv)
472 + elif args.command == "decode":
473 + decode(argv)
474 +
475 +
476 +if __name__ == "__main__":
477 + main(sys.argv)
1 +# Copyright 2020 InterDigital Communications, Inc.
2 +#
3 +# Licensed under the Apache License, Version 2.0 (the "License");
4 +# you may not use this file except in compliance with the License.
5 +# You may obtain a copy of the License at
6 +#
7 +# http://www.apache.org/licenses/LICENSE-2.0
8 +#
9 +# Unless required by applicable law or agreed to in writing, software
10 +# distributed under the License is distributed on an "AS IS" BASIS,
11 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +# See the License for the specific language governing permissions and
13 +# limitations under the License.
14 +
15 +import argparse
16 +import struct
17 +import sys
18 +import time
19 +import math
20 +
21 +from pathlib import Path
22 +
23 +import torch
24 +import torch.nn.functional as F
25 +
26 +from PIL import Image
27 +from torchvision.transforms import ToPILImage, ToTensor
28 +
29 +import compressai
30 +
31 +from compressai.zoo import models
32 +
33 +model_ids = {k: i for i, k in enumerate(models.keys())}
34 +
35 +metric_ids = {
36 + "mse": 0,
37 +}
38 +
39 +
40 +def inverse_dict(d):
41 + # We assume dict values are unique...
42 + assert len(d.keys()) == len(set(d.keys()))
43 + return {v: k for k, v in d.items()}
44 +
45 +
46 +def filesize(filepath: str) -> int:
47 + if not Path(filepath).is_file():
48 + raise ValueError(f'Invalid file "{filepath}".')
49 + return Path(filepath).stat().st_size
50 +
51 +
52 +def load_image(filepath: str) -> Image.Image:
53 + return Image.open(filepath).convert("RGB")
54 +
55 +
56 +def img2torch(img: Image.Image) -> torch.Tensor:
57 + return ToTensor()(img).unsqueeze(0)
58 +
59 +
60 +def torch2img(x: torch.Tensor) -> Image.Image:
61 + return ToPILImage()(x.clamp_(0, 1).squeeze())
62 +
63 +
64 +def write_uints(fd, values, fmt=">{:d}I"):
65 + fd.write(struct.pack(fmt.format(len(values)), *values))
66 +
67 +
68 +def write_uchars(fd, values, fmt=">{:d}B"):
69 + fd.write(struct.pack(fmt.format(len(values)), *values))
70 +
71 +
72 +def read_uints(fd, n, fmt=">{:d}I"):
73 + sz = struct.calcsize("I")
74 + return struct.unpack(fmt.format(n), fd.read(n * sz))
75 +
76 +
77 +def read_uchars(fd, n, fmt=">{:d}B"):
78 + sz = struct.calcsize("B")
79 + return struct.unpack(fmt.format(n), fd.read(n * sz))
80 +
81 +
82 +def write_bytes(fd, values, fmt=">{:d}s"):
83 + if len(values) == 0:
84 + return
85 + fd.write(struct.pack(fmt.format(len(values)), values))
86 +
87 +
88 +def read_bytes(fd, n, fmt=">{:d}s"):
89 + sz = struct.calcsize("s")
90 + return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
91 +
92 +
93 +def get_header(model_name, metric, quality):
94 + """Format header information:
95 + - 1 byte for model id
96 + - 4 bits for metric
97 + - 4 bits for quality param
98 + """
99 + metric = metric_ids[metric]
100 + code = (metric << 4) | (quality - 1 & 0x0F)
101 + return model_ids[model_name], code
102 +
103 +
104 +def parse_header(header):
105 + """Read header information from 2 bytes:
106 + - 1 byte for model id
107 + - 4 bits for metric
108 + - 4 bits for quality param
109 + """
110 + model_id, code = header
111 + quality = (code & 0x0F) + 1
112 + metric = code >> 4
113 + return (
114 + inverse_dict(model_ids)[model_id],
115 + inverse_dict(metric_ids)[metric],
116 + quality,
117 + )
118 +
119 +
120 +def pad(x, p=2 ** 6):
121 + h, w = x.size(2), x.size(3)
122 + H = (h + p - 1) // p * p
123 + W = (w + p - 1) // p * p
124 + padding_left = (W - w) // 2
125 + padding_right = W - w - padding_left
126 + padding_top = (H - h) // 2
127 + padding_bottom = H - h - padding_top
128 + return F.pad(
129 + x,
130 + (padding_left, padding_right, padding_top, padding_bottom),
131 + mode="constant",
132 + value=0,
133 + )
134 +
135 +
136 +def crop(x, size):
137 + H, W = x.size(2), x.size(3)
138 + h, w = size
139 + padding_left = (W - w) // 2
140 + padding_right = W - w - padding_left
141 + padding_top = (H - h) // 2
142 + padding_bottom = H - h - padding_top
143 + return F.pad(
144 + x,
145 + (-padding_left, -padding_right, -padding_top, -padding_bottom),
146 + mode="constant",
147 + value=0,
148 + )
149 +
150 +def compute_psnr(a, b):
151 + mse = torch.mean((a - b)**2).item()
152 + return -10 * math.log10(mse)
153 +
154 +def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
155 + compressai.set_entropy_coder(coder)
156 + enc_start = time.time()
157 +
158 + img = load_image(image)
159 + start = time.time()
160 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
161 + load_time = time.time() - start
162 +
163 + x = img2torch(img)
164 + h, w = x.size(2), x.size(3)
165 + p = 64 # maximum 6 strides of 2
166 + x = pad(x, p)
167 +
168 +# header = get_header(model, metric, quality)
169 + if i==True:
170 + strings = []
171 +
172 + with torch.no_grad():
173 + out = net.compress(x)
174 + shape = out["shape"]
175 + with Path(output).open("ab") as f:
176 + # write shape and number of encoded latents
177 + write_uints(f, (shape[0], shape[1], len(out["strings"])))
178 +
179 + for s in out["strings"]:
180 + write_uints(f, (len(s[0]),))
181 + write_bytes(f, s[0])
182 + strings.append([s[0]])
183 +
184 + with torch.no_grad():
185 + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
186 + x_recon = crop(recon_out["x_hat"], (h, w))
187 +
188 + psnr=compute_psnr(x, x_recon)
189 + else:
190 + diff=x-ref
191 + #2
192 +
193 + diff1=torch.clamp(diff, min=0.0, max=1.0)
194 + diff2=-torch.clamp(diff, min=-1.0, max=0.0)
195 +
196 + diff1=pad(diff1, p)
197 + diff2=pad(diff2, p)
198 +
199 + #2
200 +
201 + with torch.no_grad():
202 + out1 = net.compress(diff1)
203 + shape1 = out1["shape"]
204 + strings = []
205 +
206 + with Path(output).open("ab") as f:
207 + # write shape and number of encoded latents
208 + write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
209 +
210 + for s in out1["strings"]:
211 + write_uints(f, (len(s[0]),))
212 + write_bytes(f, s[0])
213 + strings.append([s[0]])
214 +
215 + with torch.no_grad():
216 + recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
217 + x_hat1 = crop(recon_out1["x_hat"], (h, w))
218 +
219 + with torch.no_grad():
220 + out = net.compress(diff2)
221 + shape = out["shape"]
222 + strings = []
223 +
224 + with Path(output).open("ab") as f:
225 + # write shape and number of encoded latents
226 + write_uints(f, (shape[0], shape[1], len(out["strings"])))
227 +
228 + for s in out["strings"]:
229 + write_uints(f, (len(s[0]),))
230 + write_bytes(f, s[0])
231 + strings.append([s[0]])
232 +
233 + with torch.no_grad():
234 + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
235 + x_hat2 = crop(recon_out["x_hat"], (h, w))
236 + x_recon=ref+x_hat1-x_hat2
237 +
238 + psnr=compute_psnr(x, x_recon)
239 + diff_img = torch2img(diff1)
240 + diff_img.save(path+"recon/v3_diff_1_"+str(ff)+"_q"+str(quality)+".png")
241 + diff_img = torch2img(diff2)
242 + diff_img.save(path+"recon/v3_diff_2_"+str(ff)+"_q"+str(quality)+".png")
243 +
244 + enc_time = time.time() - enc_start
245 + size = filesize(output)
246 + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3)
247 + with Path(log_path).open("a") as f:
248 + f.write( f" {bpp-total_bpp:.4f} | "
249 + f" {psnr:.4f} |"
250 + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
251 + recon_img = torch2img(x_recon)
252 + recon_img.save(path+"recon/v3_recon"+str(ff)+"_q"+str(quality)+".png")
253 +
254 + return psnr, bpp, x_recon, enc_time
255 +
256 +
257 +def _decode(inputpath, coder, show, frame, output=None):
258 + compressai.set_entropy_coder(coder)
259 + dec_start = time.time()
260 +
261 + with Path(inputpath).open("rb") as f:
262 + model, metric, quality = parse_header(read_uchars(f, 2))
263 + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")
264 +
265 + for i in range(frame):
266 + original_size = read_uints(f, 2)
267 + shape = read_uints(f, 2)
268 + strings = []
269 + n_strings = read_uints(f, 1)[0]
270 + for _ in range(n_strings):
271 + s = read_bytes(f, read_uints(f, 1)[0])
272 + strings.append([s])
273 +
274 + start = time.time()
275 + net = models[model](quality=quality, metric=metric, pretrained=True).eval()
276 + load_time = time.time() - start
277 +
278 + with torch.no_grad():
279 + out = net.decompress(strings, shape)
280 +
281 + x_hat = crop(out["x_hat"], original_size)
282 + img = torch2img(x_hat)
283 + dec_time = time.time() - dec_start
284 + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")
285 +
286 + if show:
287 + show_image(img)
288 + if output is not None:
289 + img.save(output+"_frame"+str(i)+".png")
290 +
291 +def show_image(img: Image.Image):
292 + from matplotlib import pyplot as plt
293 +
294 + fig, ax = plt.subplots()
295 + ax.axis("off")
296 + ax.title.set_text("Decoded image")
297 + ax.imshow(img)
298 + fig.tight_layout()
299 + plt.show()
300 +
301 +
302 +def encode(argv):
303 + parser = argparse.ArgumentParser(description="Encode image to bit-stream")
304 + parser.add_argument("image", type=str)
305 + parser.add_argument(
306 + "--model",
307 + choices=models.keys(),
308 + default=list(models.keys())[0],
309 + help="NN model to use (default: %(default)s)",
310 + )
311 + parser.add_argument(
312 + "-m",
313 + "--metric",
314 + choices=["mse"],
315 + default="mse",
316 + help="metric trained against (default: %(default)s",
317 + )
318 + parser.add_argument(
319 + "-q",
320 + "--quality",
321 + choices=list(range(1, 9)),
322 + type=int,
323 + default=3,
324 + help="Quality setting (default: %(default)s)",
325 + )
326 + parser.add_argument(
327 + "-c",
328 + "--coder",
329 + choices=compressai.available_entropy_coders(),
330 + default=compressai.available_entropy_coders()[0],
331 + help="Entropy coder (default: %(default)s)",
332 + )
333 + parser.add_argument(
334 + "-f",
335 + "--frame",
336 + type=int,
337 + default=100,
338 + help="Frame setting (default: %(default)s)",
339 + )
340 + parser.add_argument(
341 + "-fr",
342 + "--framerate",
343 + choices=[60,50,24],
344 + type=int,
345 + default=50,
346 + help="Frame rate setting (default: %(default)s)",
347 + )
348 + parser.add_argument(
349 + "-width",
350 + "--width",
351 + type=int,
352 + default=768,
353 + help="width setting (default: %(default))",
354 + )
355 + parser.add_argument(
356 + "-hight",
357 + "--hight",
358 + type=int,
359 + default=768,
360 + help="hight setting (default: %(default))",
361 + )
362 + parser.add_argument("-o", "--output", help="Output path")
363 + args = parser.parse_args(argv)
364 + path="examples/"+args.image+"/"
365 + if not args.output:
366 + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
367 + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v3.bin"
368 + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v3.txt"
369 +
370 + header = get_header(args.model, args.metric, args.quality)
371 + with Path(args.output).open("wb") as f:
372 + write_uchars(f, header)
373 + write_uints(f, (args.hight, args.width))
374 +
375 + with Path(log_path).open("w") as f:
376 + f.write(f"model : {args.model} | "
377 + f"quality : {args.quality} | "
378 + f"frames : {args.frame}\n")
379 + f.write( f"frame | bpp | "
380 + f" psnr |"
381 + f" Encoded time (model loading)\n"
382 + f" {0:3d} | ")
383 +
384 + total_psnr=0.0
385 + total_bpp=0.0
386 + total_time=0.0
387 + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
388 + img=args.image+"_frame"+str(0)+".png"
389 + total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
390 + for ff in range(1, args.frame):
391 + with Path(log_path).open("a") as f:
392 + f.write(f" {ff:3d} | ")
393 + img=args.image+"_frame"+str(ff)+".png"
394 +
395 + psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
396 + total_psnr+=psnr
397 + total_time+=time
398 +
399 + total_psnr/=args.frame
400 + total_bpp/=args.frame
401 +
402 + with Path(log_path).open("a") as f:
403 + f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
404 + f"\n Total PSNR: {total_psnr:.6f}\n"
405 + f" Total BPP: {total_bpp:.6f}\n")
406 + print(total_psnr)
407 + print(total_bpp)
408 +
409 +
410 +def decode(argv):
411 + parser = argparse.ArgumentParser(description="Decode bit-stream to imager")
412 + parser.add_argument("input", type=str)
413 + parser.add_argument(
414 + "-c",
415 + "--coder",
416 + choices=compressai.available_entropy_coders(),
417 + default=compressai.available_entropy_coders()[0],
418 + help="Entropy coder (default: %(default)s)",
419 + )
420 + parser.add_argument(
421 + "-f",
422 + "--frame",
423 + choices=list(range(1, 600)),
424 + type=int,
425 + default=100,
426 + help="Frame setting (default: %(default)s)",
427 + )
428 + parser.add_argument("--show", action="store_true")
429 + parser.add_argument("-o", "--output", help="Output path")
430 + args = parser.parse_args(argv)
431 +
432 + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444_v3.bin"
433 + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444"
434 + _decode(args.input, args.coder, args.show, args.frame, args.output)
435 +
436 +
437 +def parse_args(argv):
438 + parser = argparse.ArgumentParser(description="")
439 + parser.add_argument("command", choices=["encode", "decode"])
440 + args = parser.parse_args(argv)
441 + return args
442 +
443 +
444 +def main(argv):
445 + args = parse_args(argv[1:2])
446 + argv = argv[2:]
447 + torch.set_num_threads(1) # just to be sure
448 + if args.command == "encode":
449 + encode(argv)
450 + elif args.command == "decode":
451 + decode(argv)
452 +
453 +
454 +if __name__ == "__main__":
455 + main(sys.argv)