Showing
3 changed files
with
1339 additions
and
0 deletions
Our Encoder/codec-Copy1.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import struct | ||
17 | +import sys | ||
18 | +import time | ||
19 | +import math | ||
20 | + | ||
21 | +from pathlib import Path | ||
22 | + | ||
23 | +import torch | ||
24 | +import torch.nn.functional as F | ||
25 | + | ||
26 | +from PIL import Image | ||
27 | +from torchvision.transforms import ToPILImage, ToTensor | ||
28 | + | ||
29 | +import compressai | ||
30 | + | ||
31 | +from compressai.zoo import models | ||
32 | + | ||
33 | +model_ids = {k: i for i, k in enumerate(models.keys())} | ||
34 | + | ||
35 | +metric_ids = { | ||
36 | + "mse": 0, | ||
37 | +} | ||
38 | + | ||
39 | + | ||
40 | +def inverse_dict(d): | ||
41 | + # We assume dict values are unique... | ||
42 | + assert len(d.keys()) == len(set(d.keys())) | ||
43 | + return {v: k for k, v in d.items()} | ||
44 | + | ||
45 | + | ||
46 | +def filesize(filepath: str) -> int: | ||
47 | + if not Path(filepath).is_file(): | ||
48 | + raise ValueError(f'Invalid file "{filepath}".') | ||
49 | + return Path(filepath).stat().st_size | ||
50 | + | ||
51 | + | ||
52 | +def load_image(filepath: str) -> Image.Image: | ||
53 | + return Image.open(filepath).convert("RGB") | ||
54 | + | ||
55 | + | ||
56 | +def img2torch(img: Image.Image) -> torch.Tensor: | ||
57 | + return ToTensor()(img).unsqueeze(0) | ||
58 | + | ||
59 | + | ||
60 | +def torch2img(x: torch.Tensor) -> Image.Image: | ||
61 | + return ToPILImage()(x.clamp_(0, 1).squeeze()) | ||
62 | + | ||
63 | + | ||
64 | +def write_uints(fd, values, fmt=">{:d}I"): | ||
65 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
66 | + | ||
67 | + | ||
68 | +def write_uchars(fd, values, fmt=">{:d}B"): | ||
69 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
70 | + | ||
71 | + | ||
72 | +def read_uints(fd, n, fmt=">{:d}I"): | ||
73 | + sz = struct.calcsize("I") | ||
74 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
75 | + | ||
76 | + | ||
77 | +def read_uchars(fd, n, fmt=">{:d}B"): | ||
78 | + sz = struct.calcsize("B") | ||
79 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
80 | + | ||
81 | + | ||
82 | +def write_bytes(fd, values, fmt=">{:d}s"): | ||
83 | + if len(values) == 0: | ||
84 | + return | ||
85 | + fd.write(struct.pack(fmt.format(len(values)), values)) | ||
86 | + | ||
87 | + | ||
88 | +def read_bytes(fd, n, fmt=">{:d}s"): | ||
89 | + sz = struct.calcsize("s") | ||
90 | + return struct.unpack(fmt.format(n), fd.read(n * sz))[0] | ||
91 | + | ||
92 | + | ||
93 | +def get_header(model_name, metric, quality): | ||
94 | + """Format header information: | ||
95 | + - 1 byte for model id | ||
96 | + - 4 bits for metric | ||
97 | + - 4 bits for quality param | ||
98 | + """ | ||
99 | + metric = metric_ids[metric] | ||
100 | + code = (metric << 4) | (quality - 1 & 0x0F) | ||
101 | + return model_ids[model_name], code | ||
102 | + | ||
103 | + | ||
104 | +def parse_header(header): | ||
105 | + """Read header information from 2 bytes: | ||
106 | + - 1 byte for model id | ||
107 | + - 4 bits for metric | ||
108 | + - 4 bits for quality param | ||
109 | + """ | ||
110 | + model_id, code = header | ||
111 | + quality = (code & 0x0F) + 1 | ||
112 | + metric = code >> 4 | ||
113 | + return ( | ||
114 | + inverse_dict(model_ids)[model_id], | ||
115 | + inverse_dict(metric_ids)[metric], | ||
116 | + quality, | ||
117 | + ) | ||
118 | + | ||
119 | + | ||
120 | +def pad(x, p=2 ** 6): | ||
121 | + h, w = x.size(2), x.size(3) | ||
122 | + H = (h + p - 1) // p * p | ||
123 | + W = (w + p - 1) // p * p | ||
124 | + padding_left = (W - w) // 2 | ||
125 | + padding_right = W - w - padding_left | ||
126 | + padding_top = (H - h) // 2 | ||
127 | + padding_bottom = H - h - padding_top | ||
128 | + return F.pad( | ||
129 | + x, | ||
130 | + (padding_left, padding_right, padding_top, padding_bottom), | ||
131 | + mode="constant", | ||
132 | + value=0, | ||
133 | + ) | ||
134 | + | ||
135 | + | ||
136 | +def crop(x, size): | ||
137 | + H, W = x.size(2), x.size(3) | ||
138 | + h, w = size | ||
139 | + padding_left = (W - w) // 2 | ||
140 | + padding_right = W - w - padding_left | ||
141 | + padding_top = (H - h) // 2 | ||
142 | + padding_bottom = H - h - padding_top | ||
143 | + return F.pad( | ||
144 | + x, | ||
145 | + (-padding_left, -padding_right, -padding_top, -padding_bottom), | ||
146 | + mode="constant", | ||
147 | + value=0, | ||
148 | + ) | ||
149 | + | ||
150 | +def compute_psnr(a, b): | ||
151 | + mse = torch.mean((a - b)**2).item() | ||
152 | + return -10 * math.log10(mse) | ||
153 | + | ||
154 | +def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): | ||
155 | + compressai.set_entropy_coder(coder) | ||
156 | + enc_start = time.time() | ||
157 | + | ||
158 | + img = load_image(image) | ||
159 | + start = time.time() | ||
160 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
161 | + load_time = time.time() - start | ||
162 | + | ||
163 | + x = img2torch(img) | ||
164 | + h, w = x.size(2), x.size(3) | ||
165 | + p = 64 # maximum 6 strides of 2 | ||
166 | + x = pad(x, p) | ||
167 | + | ||
168 | +# header = get_header(model, metric, quality) | ||
169 | + | ||
170 | + strings = [] | ||
171 | + | ||
172 | + with torch.no_grad(): | ||
173 | + out = net.compress(x) | ||
174 | + shape = out["shape"] | ||
175 | + with Path(output).open("ab") as f: | ||
176 | + # write shape and number of encoded latents | ||
177 | + write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
178 | + | ||
179 | + for s in out["strings"]: | ||
180 | + write_uints(f, (len(s[0]),)) | ||
181 | + write_bytes(f, s[0]) | ||
182 | + strings.append([s[0]]) | ||
183 | + | ||
184 | + with torch.no_grad(): | ||
185 | + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
186 | + x_recon = crop(recon_out["x_hat"], (h, w)) | ||
187 | + | ||
188 | + psnr=compute_psnr(x, x_recon) | ||
189 | + | ||
190 | + if i==False: | ||
191 | + diff=x-ref | ||
192 | + diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | ||
193 | + diff_img = torch2img(diff1) | ||
194 | + diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png") | ||
195 | + | ||
196 | + enc_time = time.time() - enc_start | ||
197 | + size = filesize(output) | ||
198 | + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3) | ||
199 | + with Path(log_path).open("a") as f: | ||
200 | + f.write( f" {bpp-total_bpp:.4f} | " | ||
201 | + f" {psnr:.4f} |" | ||
202 | + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") | ||
203 | + recon_img = torch2img(x_recon) | ||
204 | + recon_img.save(path+"recon/v1_recon"+str(ff)+"_q"+str(quality)+".png") | ||
205 | + | ||
206 | + return psnr, bpp, x_recon, enc_time | ||
207 | + | ||
208 | + | ||
209 | +def _decode(inputpath, coder, show, frame, output=None): | ||
210 | + compressai.set_entropy_coder(coder) | ||
211 | + dec_start = time.time() | ||
212 | + | ||
213 | + with Path(inputpath).open("rb") as f: | ||
214 | + model, metric, quality = parse_header(read_uchars(f, 2)) | ||
215 | + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}") | ||
216 | + | ||
217 | + for i in range(frame): | ||
218 | + original_size = read_uints(f, 2) | ||
219 | + shape = read_uints(f, 2) | ||
220 | + strings = [] | ||
221 | + n_strings = read_uints(f, 1)[0] | ||
222 | + for _ in range(n_strings): | ||
223 | + s = read_bytes(f, read_uints(f, 1)[0]) | ||
224 | + strings.append([s]) | ||
225 | + | ||
226 | + start = time.time() | ||
227 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
228 | + load_time = time.time() - start | ||
229 | + | ||
230 | + with torch.no_grad(): | ||
231 | + out = net.decompress(strings, shape) | ||
232 | + | ||
233 | + x_hat = crop(out["x_hat"], original_size) | ||
234 | + img = torch2img(x_hat) | ||
235 | + dec_time = time.time() - dec_start | ||
236 | + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)") | ||
237 | + | ||
238 | + if show: | ||
239 | + show_image(img) | ||
240 | + if output is not None: | ||
241 | + img.save(output+"_frame"+str(i)+".png") | ||
242 | + | ||
243 | +def show_image(img: Image.Image): | ||
244 | + from matplotlib import pyplot as plt | ||
245 | + | ||
246 | + fig, ax = plt.subplots() | ||
247 | + ax.axis("off") | ||
248 | + ax.title.set_text("Decoded image") | ||
249 | + ax.imshow(img) | ||
250 | + fig.tight_layout() | ||
251 | + plt.show() | ||
252 | + | ||
253 | + | ||
254 | +def encode(argv): | ||
255 | + parser = argparse.ArgumentParser(description="Encode image to bit-stream") | ||
256 | + parser.add_argument("image", type=str) | ||
257 | + parser.add_argument( | ||
258 | + "--model", | ||
259 | + choices=models.keys(), | ||
260 | + default=list(models.keys())[0], | ||
261 | + help="NN model to use (default: %(default)s)", | ||
262 | + ) | ||
263 | + parser.add_argument( | ||
264 | + "-m", | ||
265 | + "--metric", | ||
266 | + choices=["mse"], | ||
267 | + default="mse", | ||
268 | + help="metric trained against (default: %(default)s", | ||
269 | + ) | ||
270 | + parser.add_argument( | ||
271 | + "-q", | ||
272 | + "--quality", | ||
273 | + choices=list(range(1, 9)), | ||
274 | + type=int, | ||
275 | + default=3, | ||
276 | + help="Quality setting (default: %(default)s)", | ||
277 | + ) | ||
278 | + parser.add_argument( | ||
279 | + "-c", | ||
280 | + "--coder", | ||
281 | + choices=compressai.available_entropy_coders(), | ||
282 | + default=compressai.available_entropy_coders()[0], | ||
283 | + help="Entropy coder (default: %(default)s)", | ||
284 | + ) | ||
285 | + parser.add_argument( | ||
286 | + "-f", | ||
287 | + "--frame", | ||
288 | + type=int, | ||
289 | + default=100, | ||
290 | + help="Frame setting (default: %(default)s)", | ||
291 | + ) | ||
292 | + parser.add_argument( | ||
293 | + "-fr", | ||
294 | + "--framerate", | ||
295 | + choices=[60,50,24], | ||
296 | + type=int, | ||
297 | + default=50, | ||
298 | + help="Frame rate setting (default: %(default)s)", | ||
299 | + ) | ||
300 | + parser.add_argument( | ||
301 | + "-width", | ||
302 | + "--width", | ||
303 | + type=int, | ||
304 | + default=768, | ||
305 | + help="width setting (default: %(default))", | ||
306 | + ) | ||
307 | + parser.add_argument( | ||
308 | + "-hight", | ||
309 | + "--hight", | ||
310 | + type=int, | ||
311 | + default=768, | ||
312 | + help="hight setting (default: %(default))", | ||
313 | + ) | ||
314 | + parser.add_argument("-o", "--output", help="Output path") | ||
315 | + args = parser.parse_args(argv) | ||
316 | + path="examples/"+args.image+"/" | ||
317 | + if not args.output: | ||
318 | + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") | ||
319 | + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v1.bin" | ||
320 | + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v1.txt" | ||
321 | + | ||
322 | + header = get_header(args.model, args.metric, args.quality) | ||
323 | + with Path(args.output).open("wb") as f: | ||
324 | + write_uchars(f, header) | ||
325 | + write_uints(f, (args.hight, args.width)) | ||
326 | + | ||
327 | + with Path(log_path).open("w") as f: | ||
328 | + f.write(f"model : {args.model} | " | ||
329 | + f"quality : {args.quality} | " | ||
330 | + f"frames : {args.frame}\n") | ||
331 | + f.write( f"frame | bpp | " | ||
332 | + f" psnr |" | ||
333 | + f" Encoded time (model loading)\n" | ||
334 | + f" {0:3d} | ") | ||
335 | + | ||
336 | + total_psnr=0.0 | ||
337 | + total_bpp=0.0 | ||
338 | + total_time=0.0 | ||
339 | + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | ||
340 | + img=args.image+"_frame"+str(0)+".png" | ||
341 | + total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | ||
342 | + for ff in range(1, args.frame): | ||
343 | + with Path(log_path).open("a") as f: | ||
344 | + f.write(f" {ff:3d} | ") | ||
345 | + img=args.image+"_frame"+str(ff)+".png" | ||
346 | + | ||
347 | + psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | ||
348 | + total_psnr+=psnr | ||
349 | + total_time+=time | ||
350 | + | ||
351 | + total_psnr/=args.frame | ||
352 | + total_bpp/=args.frame | ||
353 | + | ||
354 | + with Path(log_path).open("a") as f: | ||
355 | + f.write( f"\n Total Encoded time: {total_time:.2f}s\n" | ||
356 | + f"\n Total PSNR: {total_psnr:.6f}\n" | ||
357 | + f" Total BPP: {total_bpp:.6f}\n") | ||
358 | + print(total_psnr) | ||
359 | + print(total_bpp) | ||
360 | + | ||
361 | + | ||
362 | +def decode(argv): | ||
363 | + parser = argparse.ArgumentParser(description="Decode bit-stream to imager") | ||
364 | + parser.add_argument("input", type=str) | ||
365 | + parser.add_argument( | ||
366 | + "-c", | ||
367 | + "--coder", | ||
368 | + choices=compressai.available_entropy_coders(), | ||
369 | + default=compressai.available_entropy_coders()[0], | ||
370 | + help="Entropy coder (default: %(default)s)", | ||
371 | + ) | ||
372 | + parser.add_argument( | ||
373 | + "-f", | ||
374 | + "--frame", | ||
375 | + choices=list(range(1, 600)), | ||
376 | + type=int, | ||
377 | + default=100, | ||
378 | + help="Frame setting (default: %(default)s)", | ||
379 | + ) | ||
380 | + parser.add_argument("--show", action="store_true") | ||
381 | + parser.add_argument("-o", "--output", help="Output path") | ||
382 | + args = parser.parse_args(argv) | ||
383 | + | ||
384 | + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444.bin" | ||
385 | + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444" | ||
386 | + _decode(args.input, args.coder, args.show, args.frame, args.output) | ||
387 | + | ||
388 | + | ||
389 | +def parse_args(argv): | ||
390 | + parser = argparse.ArgumentParser(description="") | ||
391 | + parser.add_argument("command", choices=["encode", "decode"]) | ||
392 | + args = parser.parse_args(argv) | ||
393 | + return args | ||
394 | + | ||
395 | + | ||
396 | +def main(argv): | ||
397 | + args = parse_args(argv[1:2]) | ||
398 | + argv = argv[2:] | ||
399 | + torch.set_num_threads(1) # just to be sure | ||
400 | + if args.command == "encode": | ||
401 | + encode(argv) | ||
402 | + elif args.command == "decode": | ||
403 | + decode(argv) | ||
404 | + | ||
405 | + | ||
406 | +if __name__ == "__main__": | ||
407 | + main(sys.argv) |
Our Encoder/codec-Copy2.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import struct | ||
17 | +import sys | ||
18 | +import time | ||
19 | +import math | ||
20 | + | ||
21 | +from pathlib import Path | ||
22 | + | ||
23 | +import torch | ||
24 | +import torch.nn.functional as F | ||
25 | + | ||
26 | +from PIL import Image | ||
27 | +from torchvision.transforms import ToPILImage, ToTensor | ||
28 | + | ||
29 | +import compressai | ||
30 | + | ||
31 | +from compressai.zoo import models | ||
32 | + | ||
33 | +model_ids = {k: i for i, k in enumerate(models.keys())} | ||
34 | + | ||
35 | +metric_ids = { | ||
36 | + "mse": 0, | ||
37 | +} | ||
38 | + | ||
39 | + | ||
40 | +def inverse_dict(d): | ||
41 | + # We assume dict values are unique... | ||
42 | + assert len(d.keys()) == len(set(d.keys())) | ||
43 | + return {v: k for k, v in d.items()} | ||
44 | + | ||
45 | + | ||
46 | +def filesize(filepath: str) -> int: | ||
47 | + if not Path(filepath).is_file(): | ||
48 | + raise ValueError(f'Invalid file "{filepath}".') | ||
49 | + return Path(filepath).stat().st_size | ||
50 | + | ||
51 | + | ||
52 | +def load_image(filepath: str) -> Image.Image: | ||
53 | + return Image.open(filepath).convert("RGB") | ||
54 | + | ||
55 | + | ||
56 | +def img2torch(img: Image.Image) -> torch.Tensor: | ||
57 | + return ToTensor()(img).unsqueeze(0) | ||
58 | + | ||
59 | + | ||
60 | +def torch2img(x: torch.Tensor) -> Image.Image: | ||
61 | + return ToPILImage()(x.clamp_(0, 1).squeeze()) | ||
62 | + | ||
63 | + | ||
64 | +def write_uints(fd, values, fmt=">{:d}I"): | ||
65 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
66 | + | ||
67 | + | ||
68 | +def write_uchars(fd, values, fmt=">{:d}B"): | ||
69 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
70 | + | ||
71 | + | ||
72 | +def read_uints(fd, n, fmt=">{:d}I"): | ||
73 | + sz = struct.calcsize("I") | ||
74 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
75 | + | ||
76 | + | ||
77 | +def read_uchars(fd, n, fmt=">{:d}B"): | ||
78 | + sz = struct.calcsize("B") | ||
79 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
80 | + | ||
81 | + | ||
82 | +def write_bytes(fd, values, fmt=">{:d}s"): | ||
83 | + if len(values) == 0: | ||
84 | + return | ||
85 | + fd.write(struct.pack(fmt.format(len(values)), values)) | ||
86 | + | ||
87 | + | ||
88 | +def read_bytes(fd, n, fmt=">{:d}s"): | ||
89 | + sz = struct.calcsize("s") | ||
90 | + return struct.unpack(fmt.format(n), fd.read(n * sz))[0] | ||
91 | + | ||
92 | + | ||
93 | +def get_header(model_name, metric, quality): | ||
94 | + """Format header information: | ||
95 | + - 1 byte for model id | ||
96 | + - 4 bits for metric | ||
97 | + - 4 bits for quality param | ||
98 | + """ | ||
99 | + metric = metric_ids[metric] | ||
100 | + code = (metric << 4) | (quality - 1 & 0x0F) | ||
101 | + return model_ids[model_name], code | ||
102 | + | ||
103 | + | ||
104 | +def parse_header(header): | ||
105 | + """Read header information from 2 bytes: | ||
106 | + - 1 byte for model id | ||
107 | + - 4 bits for metric | ||
108 | + - 4 bits for quality param | ||
109 | + """ | ||
110 | + model_id, code = header | ||
111 | + quality = (code & 0x0F) + 1 | ||
112 | + metric = code >> 4 | ||
113 | + return ( | ||
114 | + inverse_dict(model_ids)[model_id], | ||
115 | + inverse_dict(metric_ids)[metric], | ||
116 | + quality, | ||
117 | + ) | ||
118 | + | ||
119 | + | ||
120 | +def pad(x, p=2 ** 6): | ||
121 | + h, w = x.size(2), x.size(3) | ||
122 | + H = (h + p - 1) // p * p | ||
123 | + W = (w + p - 1) // p * p | ||
124 | + padding_left = (W - w) // 2 | ||
125 | + padding_right = W - w - padding_left | ||
126 | + padding_top = (H - h) // 2 | ||
127 | + padding_bottom = H - h - padding_top | ||
128 | + return F.pad( | ||
129 | + x, | ||
130 | + (padding_left, padding_right, padding_top, padding_bottom), | ||
131 | + mode="constant", | ||
132 | + value=0, | ||
133 | + ) | ||
134 | + | ||
135 | + | ||
136 | +def crop(x, size): | ||
137 | + H, W = x.size(2), x.size(3) | ||
138 | + h, w = size | ||
139 | + padding_left = (W - w) // 2 | ||
140 | + padding_right = W - w - padding_left | ||
141 | + padding_top = (H - h) // 2 | ||
142 | + padding_bottom = H - h - padding_top | ||
143 | + return F.pad( | ||
144 | + x, | ||
145 | + (-padding_left, -padding_right, -padding_top, -padding_bottom), | ||
146 | + mode="constant", | ||
147 | + value=0, | ||
148 | + ) | ||
149 | + | ||
150 | +def compute_psnr(a, b): | ||
151 | + mse = torch.mean((a - b)**2).item() | ||
152 | + return -10 * math.log10(mse) | ||
153 | + | ||
154 | +def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | ||
155 | + compressai.set_entropy_coder(coder) | ||
156 | + enc_start = time.time() | ||
157 | + | ||
158 | + img = load_image(image) | ||
159 | + start = time.time() | ||
160 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
161 | + load_time = time.time() - start | ||
162 | + | ||
163 | + x = img2torch(img) | ||
164 | + h, w = x.size(2), x.size(3) | ||
165 | + p = 64 # maximum 6 strides of 2 | ||
166 | + x = pad(x, p) | ||
167 | + | ||
168 | +# header = get_header(model, metric, quality) | ||
169 | + if i==True: | ||
170 | + strings = [] | ||
171 | + | ||
172 | + with torch.no_grad(): | ||
173 | + out = net.compress(x) | ||
174 | + shape = out["shape"] | ||
175 | + with Path(output).open("ab") as f: | ||
176 | + # write shape and number of encoded latents | ||
177 | + write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
178 | + | ||
179 | + for s in out["strings"]: | ||
180 | + write_uints(f, (len(s[0]),)) | ||
181 | + write_bytes(f, s[0]) | ||
182 | + strings.append([s[0]]) | ||
183 | + | ||
184 | + with torch.no_grad(): | ||
185 | + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
186 | + x_recon = crop(recon_out["x_hat"], (h, w)) | ||
187 | + | ||
188 | + psnr=compute_psnr(x, x_recon) | ||
189 | + else: | ||
190 | + diff=x-ref | ||
191 | + #1 | ||
192 | + diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | ||
193 | + | ||
194 | + #2 | ||
195 | + ''' | ||
196 | + diff1=torch.clamp(diff, min=0.0, max=1.0) | ||
197 | + diff2=-torch.clamp(diff, min=-1.0, max=0.0) | ||
198 | + | ||
199 | + diff1=pad(diff1, p) | ||
200 | + diff2=pad(diff2, p) | ||
201 | + ''' | ||
202 | + #1 | ||
203 | + | ||
204 | + with torch.no_grad(): | ||
205 | + out1 = net.compress(diff1) | ||
206 | + shape1 = out1["shape"] | ||
207 | + strings = [] | ||
208 | + | ||
209 | + with Path(output).open("ab") as f: | ||
210 | + # write shape and number of encoded latents | ||
211 | + write_uints(f, (shape1[0], shape1[1], len(out1["strings"]))) | ||
212 | + | ||
213 | + for s in out1["strings"]: | ||
214 | + write_uints(f, (len(s[0]),)) | ||
215 | + write_bytes(f, s[0]) | ||
216 | + strings.append([s[0]]) | ||
217 | + | ||
218 | + with torch.no_grad(): | ||
219 | + recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | ||
220 | + x_hat1 = crop(recon_out["x_hat"], (h, w)) | ||
221 | + | ||
222 | + #2 | ||
223 | + ''' | ||
224 | + with torch.no_grad(): | ||
225 | + out1 = net.compress(diff1) | ||
226 | + shape1 = out1["shape"] | ||
227 | + strings = [] | ||
228 | + | ||
229 | + with Path(output).open("ab") as f: | ||
230 | + # write shape and number of encoded latents | ||
231 | + write_uints(f, (shape1[0], shape1[1], len(out1["strings"]))) | ||
232 | + | ||
233 | + for s in out1["strings"]: | ||
234 | + write_uints(f, (len(s[0]),)) | ||
235 | + write_bytes(f, s[0]) | ||
236 | + strings.append([s[0]]) | ||
237 | + | ||
238 | + with torch.no_grad(): | ||
239 | + recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | ||
240 | + x_hat1 = crop(recon_out["x_hat"], (h, w)) | ||
241 | + with torch.no_grad(): | ||
242 | + out = net.compress(diff2) | ||
243 | + shape = out["shape"] | ||
244 | + strings = [] | ||
245 | + | ||
246 | + with Path(output).open("ab") as f: | ||
247 | + # write shape and number of encoded latents | ||
248 | + write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
249 | + | ||
250 | + for s in out["strings"]: | ||
251 | + write_uints(f, (len(s[0]),)) | ||
252 | + write_bytes(f, s[0]) | ||
253 | + strings.append([s[0]]) | ||
254 | + | ||
255 | + with torch.no_grad(): | ||
256 | + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
257 | + x_hat2 = crop(recon_out["x_hat"], (h, w)) | ||
258 | + x_recon=ref+x_hat1-x_hat2 | ||
259 | + ''' | ||
260 | + | ||
261 | + x_recon=ref+x_hat1-0.5 | ||
262 | + psnr=compute_psnr(x, x_recon) | ||
263 | + diff_img = torch2img(diff1) | ||
264 | + diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png") | ||
265 | + | ||
266 | + enc_time = time.time() - enc_start | ||
267 | + size = filesize(output) | ||
268 | + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3) | ||
269 | + with Path(log_path).open("a") as f: | ||
270 | + f.write( f" {bpp-total_bpp:.4f} | " | ||
271 | + f" {psnr:.4f} |" | ||
272 | + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") | ||
273 | + recon_img = torch2img(x_recon) | ||
274 | + recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") | ||
275 | + | ||
276 | + return psnr, bpp, x_recon, enc_time | ||
277 | + | ||
278 | + | ||
279 | +def _decode(inputpath, coder, show, frame, output=None): | ||
280 | + compressai.set_entropy_coder(coder) | ||
281 | + dec_start = time.time() | ||
282 | + | ||
283 | + with Path(inputpath).open("rb") as f: | ||
284 | + model, metric, quality = parse_header(read_uchars(f, 2)) | ||
285 | + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}") | ||
286 | + | ||
287 | + for i in range(frame): | ||
288 | + original_size = read_uints(f, 2) | ||
289 | + shape = read_uints(f, 2) | ||
290 | + strings = [] | ||
291 | + n_strings = read_uints(f, 1)[0] | ||
292 | + for _ in range(n_strings): | ||
293 | + s = read_bytes(f, read_uints(f, 1)[0]) | ||
294 | + strings.append([s]) | ||
295 | + | ||
296 | + start = time.time() | ||
297 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
298 | + load_time = time.time() - start | ||
299 | + | ||
300 | + with torch.no_grad(): | ||
301 | + out = net.decompress(strings, shape) | ||
302 | + | ||
303 | + x_hat = crop(out["x_hat"], original_size) | ||
304 | + img = torch2img(x_hat) | ||
305 | + dec_time = time.time() - dec_start | ||
306 | + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)") | ||
307 | + | ||
308 | + if show: | ||
309 | + show_image(img) | ||
310 | + if output is not None: | ||
311 | + img.save(output+"_frame"+str(i)+".png") | ||
312 | + | ||
313 | +def show_image(img: Image.Image): | ||
314 | + from matplotlib import pyplot as plt | ||
315 | + | ||
316 | + fig, ax = plt.subplots() | ||
317 | + ax.axis("off") | ||
318 | + ax.title.set_text("Decoded image") | ||
319 | + ax.imshow(img) | ||
320 | + fig.tight_layout() | ||
321 | + plt.show() | ||
322 | + | ||
323 | + | ||
324 | +def encode(argv): | ||
325 | + parser = argparse.ArgumentParser(description="Encode image to bit-stream") | ||
326 | + parser.add_argument("image", type=str) | ||
327 | + parser.add_argument( | ||
328 | + "--model", | ||
329 | + choices=models.keys(), | ||
330 | + default=list(models.keys())[0], | ||
331 | + help="NN model to use (default: %(default)s)", | ||
332 | + ) | ||
333 | + parser.add_argument( | ||
334 | + "-m", | ||
335 | + "--metric", | ||
336 | + choices=["mse"], | ||
337 | + default="mse", | ||
338 | + help="metric trained against (default: %(default)s", | ||
339 | + ) | ||
340 | + parser.add_argument( | ||
341 | + "-q", | ||
342 | + "--quality", | ||
343 | + choices=list(range(1, 9)), | ||
344 | + type=int, | ||
345 | + default=3, | ||
346 | + help="Quality setting (default: %(default)s)", | ||
347 | + ) | ||
348 | + parser.add_argument( | ||
349 | + "-c", | ||
350 | + "--coder", | ||
351 | + choices=compressai.available_entropy_coders(), | ||
352 | + default=compressai.available_entropy_coders()[0], | ||
353 | + help="Entropy coder (default: %(default)s)", | ||
354 | + ) | ||
355 | + parser.add_argument( | ||
356 | + "-f", | ||
357 | + "--frame", | ||
358 | + type=int, | ||
359 | + default=100, | ||
360 | + help="Frame setting (default: %(default)s)", | ||
361 | + ) | ||
362 | + parser.add_argument( | ||
363 | + "-fr", | ||
364 | + "--framerate", | ||
365 | + choices=[60,50,24], | ||
366 | + type=int, | ||
367 | + default=50, | ||
368 | + help="Frame rate setting (default: %(default)s)", | ||
369 | + ) | ||
370 | + parser.add_argument( | ||
371 | + "-width", | ||
372 | + "--width", | ||
373 | + type=int, | ||
374 | + default=768, | ||
375 | + help="width setting (default: %(default))", | ||
376 | + ) | ||
377 | + parser.add_argument( | ||
378 | + "-height", | ||
379 | + "--height", | ||
380 | + type=int, | ||
381 | + default=768, | ||
382 | + help="hight setting (default: %(default))", | ||
383 | + ) | ||
384 | + parser.add_argument("-o", "--output", help="Output path") | ||
385 | + args = parser.parse_args(argv) | ||
386 | + path="examples/"+args.image+"/" | ||
387 | + if not args.output: | ||
388 | + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") | ||
389 | + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin" | ||
390 | + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt" | ||
391 | + | ||
392 | + header = get_header(args.model, args.metric, args.quality) | ||
393 | + with Path(args.output).open("wb") as f: | ||
394 | + write_uchars(f, header) | ||
395 | + write_uints(f, (args.height, args.width)) | ||
396 | + | ||
397 | + with Path(log_path).open("w") as f: | ||
398 | + f.write(f"model : {args.model} | " | ||
399 | + f"quality : {args.quality} | " | ||
400 | + f"frames : {args.frame}\n") | ||
401 | + f.write( f"frame | bpp | " | ||
402 | + f" psnr |" | ||
403 | + f" Encoded time (model loading)\n" | ||
404 | + f" {0:3d} | ") | ||
405 | + | ||
406 | + total_psnr=0.0 | ||
407 | + total_bpp=0.0 | ||
408 | + total_time=0.0 | ||
409 | + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | ||
410 | + img=args.image+"_frame"+str(0)+".png" | ||
411 | + total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | ||
412 | + for ff in range(1, args.frame): | ||
413 | + with Path(log_path).open("a") as f: | ||
414 | + f.write(f" {ff:3d} | ") | ||
415 | + img=args.image+"_frame"+str(ff)+".png" | ||
416 | + | ||
417 | + psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | ||
418 | + total_psnr+=psnr | ||
419 | + total_time+=time | ||
420 | + | ||
421 | + total_psnr/=args.frame | ||
422 | + total_bpp/=args.frame | ||
423 | + | ||
424 | + with Path(log_path).open("a") as f: | ||
425 | + f.write( f"\n Total Encoded time: {total_time:.2f}s\n" | ||
426 | + f"\n Total PSNR: {total_psnr:.6f}\n" | ||
427 | + f" Total BPP: {total_bpp:.6f}\n") | ||
428 | + print(total_psnr) | ||
429 | + print(total_bpp) | ||
430 | + | ||
431 | + | ||
432 | +def decode(argv): | ||
433 | + parser = argparse.ArgumentParser(description="Decode bit-stream to imager") | ||
434 | + parser.add_argument("input", type=str) | ||
435 | + parser.add_argument( | ||
436 | + "-c", | ||
437 | + "--coder", | ||
438 | + choices=compressai.available_entropy_coders(), | ||
439 | + default=compressai.available_entropy_coders()[0], | ||
440 | + help="Entropy coder (default: %(default)s)", | ||
441 | + ) | ||
442 | + parser.add_argument( | ||
443 | + "-f", | ||
444 | + "--frame", | ||
445 | + choices=list(range(1, 600)), | ||
446 | + type=int, | ||
447 | + default=100, | ||
448 | + help="Frame setting (default: %(default)s)", | ||
449 | + ) | ||
450 | + parser.add_argument("--show", action="store_true") | ||
451 | + parser.add_argument("-o", "--output", help="Output path") | ||
452 | + args = parser.parse_args(argv) | ||
453 | + | ||
454 | + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444.bin" | ||
455 | + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444" | ||
456 | + _decode(args.input, args.coder, args.show, args.frame, args.output) | ||
457 | + | ||
458 | + | ||
459 | +def parse_args(argv): | ||
460 | + parser = argparse.ArgumentParser(description="") | ||
461 | + parser.add_argument("command", choices=["encode", "decode"]) | ||
462 | + args = parser.parse_args(argv) | ||
463 | + return args | ||
464 | + | ||
465 | + | ||
466 | +def main(argv): | ||
467 | + args = parse_args(argv[1:2]) | ||
468 | + argv = argv[2:] | ||
469 | + torch.set_num_threads(1) # just to be sure | ||
470 | + if args.command == "encode": | ||
471 | + encode(argv) | ||
472 | + elif args.command == "decode": | ||
473 | + decode(argv) | ||
474 | + | ||
475 | + | ||
476 | +if __name__ == "__main__": | ||
477 | + main(sys.argv) |
Our Encoder/codec-Copy3.py
0 → 100644
1 | +# Copyright 2020 InterDigital Communications, Inc. | ||
2 | +# | ||
3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
4 | +# you may not use this file except in compliance with the License. | ||
5 | +# You may obtain a copy of the License at | ||
6 | +# | ||
7 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
8 | +# | ||
9 | +# Unless required by applicable law or agreed to in writing, software | ||
10 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
12 | +# See the License for the specific language governing permissions and | ||
13 | +# limitations under the License. | ||
14 | + | ||
15 | +import argparse | ||
16 | +import struct | ||
17 | +import sys | ||
18 | +import time | ||
19 | +import math | ||
20 | + | ||
21 | +from pathlib import Path | ||
22 | + | ||
23 | +import torch | ||
24 | +import torch.nn.functional as F | ||
25 | + | ||
26 | +from PIL import Image | ||
27 | +from torchvision.transforms import ToPILImage, ToTensor | ||
28 | + | ||
29 | +import compressai | ||
30 | + | ||
31 | +from compressai.zoo import models | ||
32 | + | ||
33 | +model_ids = {k: i for i, k in enumerate(models.keys())} | ||
34 | + | ||
35 | +metric_ids = { | ||
36 | + "mse": 0, | ||
37 | +} | ||
38 | + | ||
39 | + | ||
40 | +def inverse_dict(d): | ||
41 | + # We assume dict values are unique... | ||
42 | + assert len(d.keys()) == len(set(d.keys())) | ||
43 | + return {v: k for k, v in d.items()} | ||
44 | + | ||
45 | + | ||
46 | +def filesize(filepath: str) -> int: | ||
47 | + if not Path(filepath).is_file(): | ||
48 | + raise ValueError(f'Invalid file "{filepath}".') | ||
49 | + return Path(filepath).stat().st_size | ||
50 | + | ||
51 | + | ||
52 | +def load_image(filepath: str) -> Image.Image: | ||
53 | + return Image.open(filepath).convert("RGB") | ||
54 | + | ||
55 | + | ||
56 | +def img2torch(img: Image.Image) -> torch.Tensor: | ||
57 | + return ToTensor()(img).unsqueeze(0) | ||
58 | + | ||
59 | + | ||
60 | +def torch2img(x: torch.Tensor) -> Image.Image: | ||
61 | + return ToPILImage()(x.clamp_(0, 1).squeeze()) | ||
62 | + | ||
63 | + | ||
64 | +def write_uints(fd, values, fmt=">{:d}I"): | ||
65 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
66 | + | ||
67 | + | ||
68 | +def write_uchars(fd, values, fmt=">{:d}B"): | ||
69 | + fd.write(struct.pack(fmt.format(len(values)), *values)) | ||
70 | + | ||
71 | + | ||
72 | +def read_uints(fd, n, fmt=">{:d}I"): | ||
73 | + sz = struct.calcsize("I") | ||
74 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
75 | + | ||
76 | + | ||
77 | +def read_uchars(fd, n, fmt=">{:d}B"): | ||
78 | + sz = struct.calcsize("B") | ||
79 | + return struct.unpack(fmt.format(n), fd.read(n * sz)) | ||
80 | + | ||
81 | + | ||
82 | +def write_bytes(fd, values, fmt=">{:d}s"): | ||
83 | + if len(values) == 0: | ||
84 | + return | ||
85 | + fd.write(struct.pack(fmt.format(len(values)), values)) | ||
86 | + | ||
87 | + | ||
88 | +def read_bytes(fd, n, fmt=">{:d}s"): | ||
89 | + sz = struct.calcsize("s") | ||
90 | + return struct.unpack(fmt.format(n), fd.read(n * sz))[0] | ||
91 | + | ||
92 | + | ||
93 | +def get_header(model_name, metric, quality): | ||
94 | + """Format header information: | ||
95 | + - 1 byte for model id | ||
96 | + - 4 bits for metric | ||
97 | + - 4 bits for quality param | ||
98 | + """ | ||
99 | + metric = metric_ids[metric] | ||
100 | + code = (metric << 4) | (quality - 1 & 0x0F) | ||
101 | + return model_ids[model_name], code | ||
102 | + | ||
103 | + | ||
104 | +def parse_header(header): | ||
105 | + """Read header information from 2 bytes: | ||
106 | + - 1 byte for model id | ||
107 | + - 4 bits for metric | ||
108 | + - 4 bits for quality param | ||
109 | + """ | ||
110 | + model_id, code = header | ||
111 | + quality = (code & 0x0F) + 1 | ||
112 | + metric = code >> 4 | ||
113 | + return ( | ||
114 | + inverse_dict(model_ids)[model_id], | ||
115 | + inverse_dict(metric_ids)[metric], | ||
116 | + quality, | ||
117 | + ) | ||
118 | + | ||
119 | + | ||
120 | +def pad(x, p=2 ** 6): | ||
121 | + h, w = x.size(2), x.size(3) | ||
122 | + H = (h + p - 1) // p * p | ||
123 | + W = (w + p - 1) // p * p | ||
124 | + padding_left = (W - w) // 2 | ||
125 | + padding_right = W - w - padding_left | ||
126 | + padding_top = (H - h) // 2 | ||
127 | + padding_bottom = H - h - padding_top | ||
128 | + return F.pad( | ||
129 | + x, | ||
130 | + (padding_left, padding_right, padding_top, padding_bottom), | ||
131 | + mode="constant", | ||
132 | + value=0, | ||
133 | + ) | ||
134 | + | ||
135 | + | ||
136 | +def crop(x, size): | ||
137 | + H, W = x.size(2), x.size(3) | ||
138 | + h, w = size | ||
139 | + padding_left = (W - w) // 2 | ||
140 | + padding_right = W - w - padding_left | ||
141 | + padding_top = (H - h) // 2 | ||
142 | + padding_bottom = H - h - padding_top | ||
143 | + return F.pad( | ||
144 | + x, | ||
145 | + (-padding_left, -padding_right, -padding_top, -padding_bottom), | ||
146 | + mode="constant", | ||
147 | + value=0, | ||
148 | + ) | ||
149 | + | ||
150 | +def compute_psnr(a, b): | ||
151 | + mse = torch.mean((a - b)**2).item() | ||
152 | + return -10 * math.log10(mse) | ||
153 | + | ||
154 | +def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | ||
155 | + compressai.set_entropy_coder(coder) | ||
156 | + enc_start = time.time() | ||
157 | + | ||
158 | + img = load_image(image) | ||
159 | + start = time.time() | ||
160 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
161 | + load_time = time.time() - start | ||
162 | + | ||
163 | + x = img2torch(img) | ||
164 | + h, w = x.size(2), x.size(3) | ||
165 | + p = 64 # maximum 6 strides of 2 | ||
166 | + x = pad(x, p) | ||
167 | + | ||
168 | +# header = get_header(model, metric, quality) | ||
169 | + if i==True: | ||
170 | + strings = [] | ||
171 | + | ||
172 | + with torch.no_grad(): | ||
173 | + out = net.compress(x) | ||
174 | + shape = out["shape"] | ||
175 | + with Path(output).open("ab") as f: | ||
176 | + # write shape and number of encoded latents | ||
177 | + write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
178 | + | ||
179 | + for s in out["strings"]: | ||
180 | + write_uints(f, (len(s[0]),)) | ||
181 | + write_bytes(f, s[0]) | ||
182 | + strings.append([s[0]]) | ||
183 | + | ||
184 | + with torch.no_grad(): | ||
185 | + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
186 | + x_recon = crop(recon_out["x_hat"], (h, w)) | ||
187 | + | ||
188 | + psnr=compute_psnr(x, x_recon) | ||
189 | + else: | ||
190 | + diff=x-ref | ||
191 | + #2 | ||
192 | + | ||
193 | + diff1=torch.clamp(diff, min=0.0, max=1.0) | ||
194 | + diff2=-torch.clamp(diff, min=-1.0, max=0.0) | ||
195 | + | ||
196 | + diff1=pad(diff1, p) | ||
197 | + diff2=pad(diff2, p) | ||
198 | + | ||
199 | + #2 | ||
200 | + | ||
201 | + with torch.no_grad(): | ||
202 | + out1 = net.compress(diff1) | ||
203 | + shape1 = out1["shape"] | ||
204 | + strings = [] | ||
205 | + | ||
206 | + with Path(output).open("ab") as f: | ||
207 | + # write shape and number of encoded latents | ||
208 | + write_uints(f, (shape1[0], shape1[1], len(out1["strings"]))) | ||
209 | + | ||
210 | + for s in out1["strings"]: | ||
211 | + write_uints(f, (len(s[0]),)) | ||
212 | + write_bytes(f, s[0]) | ||
213 | + strings.append([s[0]]) | ||
214 | + | ||
215 | + with torch.no_grad(): | ||
216 | + recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | ||
217 | + x_hat1 = crop(recon_out1["x_hat"], (h, w)) | ||
218 | + | ||
219 | + with torch.no_grad(): | ||
220 | + out = net.compress(diff2) | ||
221 | + shape = out["shape"] | ||
222 | + strings = [] | ||
223 | + | ||
224 | + with Path(output).open("ab") as f: | ||
225 | + # write shape and number of encoded latents | ||
226 | + write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
227 | + | ||
228 | + for s in out["strings"]: | ||
229 | + write_uints(f, (len(s[0]),)) | ||
230 | + write_bytes(f, s[0]) | ||
231 | + strings.append([s[0]]) | ||
232 | + | ||
233 | + with torch.no_grad(): | ||
234 | + recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
235 | + x_hat2 = crop(recon_out["x_hat"], (h, w)) | ||
236 | + x_recon=ref+x_hat1-x_hat2 | ||
237 | + | ||
238 | + psnr=compute_psnr(x, x_recon) | ||
239 | + diff_img = torch2img(diff1) | ||
240 | + diff_img.save(path+"recon/v3_diff_1_"+str(ff)+"_q"+str(quality)+".png") | ||
241 | + diff_img = torch2img(diff2) | ||
242 | + diff_img.save(path+"recon/v3_diff_2_"+str(ff)+"_q"+str(quality)+".png") | ||
243 | + | ||
244 | + enc_time = time.time() - enc_start | ||
245 | + size = filesize(output) | ||
246 | + bpp = float(size) * 8 / (img.size[0] * img.size[1]*3) | ||
247 | + with Path(log_path).open("a") as f: | ||
248 | + f.write( f" {bpp-total_bpp:.4f} | " | ||
249 | + f" {psnr:.4f} |" | ||
250 | + f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") | ||
251 | + recon_img = torch2img(x_recon) | ||
252 | + recon_img.save(path+"recon/v3_recon"+str(ff)+"_q"+str(quality)+".png") | ||
253 | + | ||
254 | + return psnr, bpp, x_recon, enc_time | ||
255 | + | ||
256 | + | ||
257 | +def _decode(inputpath, coder, show, frame, output=None): | ||
258 | + compressai.set_entropy_coder(coder) | ||
259 | + dec_start = time.time() | ||
260 | + | ||
261 | + with Path(inputpath).open("rb") as f: | ||
262 | + model, metric, quality = parse_header(read_uchars(f, 2)) | ||
263 | + print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}") | ||
264 | + | ||
265 | + for i in range(frame): | ||
266 | + original_size = read_uints(f, 2) | ||
267 | + shape = read_uints(f, 2) | ||
268 | + strings = [] | ||
269 | + n_strings = read_uints(f, 1)[0] | ||
270 | + for _ in range(n_strings): | ||
271 | + s = read_bytes(f, read_uints(f, 1)[0]) | ||
272 | + strings.append([s]) | ||
273 | + | ||
274 | + start = time.time() | ||
275 | + net = models[model](quality=quality, metric=metric, pretrained=True).eval() | ||
276 | + load_time = time.time() - start | ||
277 | + | ||
278 | + with torch.no_grad(): | ||
279 | + out = net.decompress(strings, shape) | ||
280 | + | ||
281 | + x_hat = crop(out["x_hat"], original_size) | ||
282 | + img = torch2img(x_hat) | ||
283 | + dec_time = time.time() - dec_start | ||
284 | + print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)") | ||
285 | + | ||
286 | + if show: | ||
287 | + show_image(img) | ||
288 | + if output is not None: | ||
289 | + img.save(output+"_frame"+str(i)+".png") | ||
290 | + | ||
291 | +def show_image(img: Image.Image): | ||
292 | + from matplotlib import pyplot as plt | ||
293 | + | ||
294 | + fig, ax = plt.subplots() | ||
295 | + ax.axis("off") | ||
296 | + ax.title.set_text("Decoded image") | ||
297 | + ax.imshow(img) | ||
298 | + fig.tight_layout() | ||
299 | + plt.show() | ||
300 | + | ||
301 | + | ||
302 | +def encode(argv): | ||
303 | + parser = argparse.ArgumentParser(description="Encode image to bit-stream") | ||
304 | + parser.add_argument("image", type=str) | ||
305 | + parser.add_argument( | ||
306 | + "--model", | ||
307 | + choices=models.keys(), | ||
308 | + default=list(models.keys())[0], | ||
309 | + help="NN model to use (default: %(default)s)", | ||
310 | + ) | ||
311 | + parser.add_argument( | ||
312 | + "-m", | ||
313 | + "--metric", | ||
314 | + choices=["mse"], | ||
315 | + default="mse", | ||
316 | + help="metric trained against (default: %(default)s", | ||
317 | + ) | ||
318 | + parser.add_argument( | ||
319 | + "-q", | ||
320 | + "--quality", | ||
321 | + choices=list(range(1, 9)), | ||
322 | + type=int, | ||
323 | + default=3, | ||
324 | + help="Quality setting (default: %(default)s)", | ||
325 | + ) | ||
326 | + parser.add_argument( | ||
327 | + "-c", | ||
328 | + "--coder", | ||
329 | + choices=compressai.available_entropy_coders(), | ||
330 | + default=compressai.available_entropy_coders()[0], | ||
331 | + help="Entropy coder (default: %(default)s)", | ||
332 | + ) | ||
333 | + parser.add_argument( | ||
334 | + "-f", | ||
335 | + "--frame", | ||
336 | + type=int, | ||
337 | + default=100, | ||
338 | + help="Frame setting (default: %(default)s)", | ||
339 | + ) | ||
340 | + parser.add_argument( | ||
341 | + "-fr", | ||
342 | + "--framerate", | ||
343 | + choices=[60,50,24], | ||
344 | + type=int, | ||
345 | + default=50, | ||
346 | + help="Frame rate setting (default: %(default)s)", | ||
347 | + ) | ||
348 | + parser.add_argument( | ||
349 | + "-width", | ||
350 | + "--width", | ||
351 | + type=int, | ||
352 | + default=768, | ||
353 | + help="width setting (default: %(default))", | ||
354 | + ) | ||
355 | + parser.add_argument( | ||
356 | + "-hight", | ||
357 | + "--hight", | ||
358 | + type=int, | ||
359 | + default=768, | ||
360 | + help="hight setting (default: %(default))", | ||
361 | + ) | ||
362 | + parser.add_argument("-o", "--output", help="Output path") | ||
363 | + args = parser.parse_args(argv) | ||
364 | + path="examples/"+args.image+"/" | ||
365 | + if not args.output: | ||
366 | + #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") | ||
367 | + args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v3.bin" | ||
368 | + log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v3.txt" | ||
369 | + | ||
370 | + header = get_header(args.model, args.metric, args.quality) | ||
371 | + with Path(args.output).open("wb") as f: | ||
372 | + write_uchars(f, header) | ||
373 | + write_uints(f, (args.hight, args.width)) | ||
374 | + | ||
375 | + with Path(log_path).open("w") as f: | ||
376 | + f.write(f"model : {args.model} | " | ||
377 | + f"quality : {args.quality} | " | ||
378 | + f"frames : {args.frame}\n") | ||
379 | + f.write( f"frame | bpp | " | ||
380 | + f" psnr |" | ||
381 | + f" Encoded time (model loading)\n" | ||
382 | + f" {0:3d} | ") | ||
383 | + | ||
384 | + total_psnr=0.0 | ||
385 | + total_bpp=0.0 | ||
386 | + total_time=0.0 | ||
387 | + args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | ||
388 | + img=args.image+"_frame"+str(0)+".png" | ||
389 | + total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | ||
390 | + for ff in range(1, args.frame): | ||
391 | + with Path(log_path).open("a") as f: | ||
392 | + f.write(f" {ff:3d} | ") | ||
393 | + img=args.image+"_frame"+str(ff)+".png" | ||
394 | + | ||
395 | + psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | ||
396 | + total_psnr+=psnr | ||
397 | + total_time+=time | ||
398 | + | ||
399 | + total_psnr/=args.frame | ||
400 | + total_bpp/=args.frame | ||
401 | + | ||
402 | + with Path(log_path).open("a") as f: | ||
403 | + f.write( f"\n Total Encoded time: {total_time:.2f}s\n" | ||
404 | + f"\n Total PSNR: {total_psnr:.6f}\n" | ||
405 | + f" Total BPP: {total_bpp:.6f}\n") | ||
406 | + print(total_psnr) | ||
407 | + print(total_bpp) | ||
408 | + | ||
409 | + | ||
410 | +def decode(argv): | ||
411 | + parser = argparse.ArgumentParser(description="Decode bit-stream to imager") | ||
412 | + parser.add_argument("input", type=str) | ||
413 | + parser.add_argument( | ||
414 | + "-c", | ||
415 | + "--coder", | ||
416 | + choices=compressai.available_entropy_coders(), | ||
417 | + default=compressai.available_entropy_coders()[0], | ||
418 | + help="Entropy coder (default: %(default)s)", | ||
419 | + ) | ||
420 | + parser.add_argument( | ||
421 | + "-f", | ||
422 | + "--frame", | ||
423 | + choices=list(range(1, 600)), | ||
424 | + type=int, | ||
425 | + default=100, | ||
426 | + help="Frame setting (default: %(default)s)", | ||
427 | + ) | ||
428 | + parser.add_argument("--show", action="store_true") | ||
429 | + parser.add_argument("-o", "--output", help="Output path") | ||
430 | + args = parser.parse_args(argv) | ||
431 | + | ||
432 | + args.input="examples/"+args.input+"/"+args.input+"_768x768_"+str(args.frame//2)+"_8bit_444_v3.bin" | ||
433 | + args.output="examples/recon/"+args.output+"/"+args.output+"_768x768_"+str(50)+"_8bit_444" | ||
434 | + _decode(args.input, args.coder, args.show, args.frame, args.output) | ||
435 | + | ||
436 | + | ||
437 | +def parse_args(argv): | ||
438 | + parser = argparse.ArgumentParser(description="") | ||
439 | + parser.add_argument("command", choices=["encode", "decode"]) | ||
440 | + args = parser.parse_args(argv) | ||
441 | + return args | ||
442 | + | ||
443 | + | ||
444 | +def main(argv): | ||
445 | + args = parse_args(argv[1:2]) | ||
446 | + argv = argv[2:] | ||
447 | + torch.set_num_threads(1) # just to be sure | ||
448 | + if args.command == "encode": | ||
449 | + encode(argv) | ||
450 | + elif args.command == "decode": | ||
451 | + decode(argv) | ||
452 | + | ||
453 | + | ||
454 | +if __name__ == "__main__": | ||
455 | + main(sys.argv) |
-
Please register or login to post a comment